From 5dd4df7e182e90c88f78b88a586df271be1a3e98 Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Wed, 31 Dec 2025 21:00:06 +0800
Subject: [PATCH] [fix] MoE unused parameters under DDP

MoE routing activates only a subset of experts per step, so some expert
parameters receive no gradient and DDP's gradient reducer fails waiting
for them. Enable find_unused_parameters when use_moe is set, keeping the
cheaper default for dense models.

---
 trainer/train_distill_reason.py | 2 +-
 trainer/train_distillation.py   | 2 +-
 trainer/train_dpo.py            | 2 +-
 trainer/train_full_sft.py       | 4 ++--
 trainer/train_grpo.py           | 2 +-
 trainer/train_lora.py           | 2 +-
 trainer/train_ppo.py            | 4 ++--
 trainer/train_pretrain.py       | 4 ++--
 trainer/train_spo.py            | 2 +-
 9 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index be69738..6ff7ea3 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -161,7 +161,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index 712bac6..c99bbd7 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -210,7 +210,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config_student, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index b20b53f..d3c3405 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -195,7 +195,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index cea2ffb..6b3f426 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -96,7 +96,7 @@ if __name__ == "__main__":
     parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
     parser.add_argument("--save_interval", type=int, default=1000, help="Model save interval")
     parser.add_argument('--hidden_size', default=512, type=int, help="Hidden layer dimension")
-    parser.add_argument('--num_hidden_layers', default=8, type=int, help="Number of hidden layers")
+    parser.add_argument('--num_hidden_layers', default=16, type=int, help="Number of hidden layers")
     parser.add_argument('--max_seq_len', default=340, type=int, help="Maximum truncation length for training (Chinese: 1 token ≈ 1.5-1.7 characters)")
     parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="Whether to use the MoE architecture (0=no, 1=yes)")
     parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="Training data path")
@@ -149,7 +149,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 897d9a8..580768a 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -272,7 +272,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index 89cb7a9..1fba33d 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -162,7 +162,7 @@ if __name__ == "__main__":
     # ========== 8. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 9. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py
index cb0ec38..67cb301 100644
--- a/trainer/train_ppo.py
+++ b/trainer/train_ppo.py
@@ -344,8 +344,8 @@ if __name__ == "__main__":
     if dist.is_initialized():
         actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
-        critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
+        actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
         old_actor_model.to(args.device)
 
     # ========== 8. Start training ==========
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 5f05341..09eb284 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -95,7 +95,7 @@ if __name__ == "__main__":
     parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
     parser.add_argument("--save_interval", type=int, default=1000, help="Model save interval")
     parser.add_argument('--hidden_size', default=512, type=int, help="Hidden layer dimension")
-    parser.add_argument('--num_hidden_layers', default=8, type=int, help="Number of hidden layers")
+    parser.add_argument('--num_hidden_layers', default=16, type=int, help="Number of hidden layers")
     parser.add_argument('--max_seq_len', default=340, type=int, help="Maximum truncation length for training (Chinese: 1 token ≈ 1.5-1.7 characters)")
     parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="Whether to use the MoE architecture (0=no, 1=yes)")
     parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="Pretraining data path")
@@ -148,7 +148,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index 37493e4..700b14e 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -322,7 +322,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank])
+        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
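
Note for reviewers: the sketch below is a minimal, self-contained reproduction of the failure this patch fixes, not code from this repository (ToyMoE, its alternating "router", and the file name toy_moe_ddp.py are hypothetical stand-ins). It shows an expert going unused each step, which is what makes DDP's reducer raise on the next iteration unless find_unused_parameters=True is passed; run it with, e.g., torchrun --nproc_per_node=2 toy_moe_ddp.py.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class ToyMoE(nn.Module):
    """Toy expert layer: one of two experts runs per step, so the other
    expert's parameters receive no gradient -- the MoE routing situation."""
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.step = 0

    def forward(self, x):
        out = self.experts[self.step % 2](x)  # alternate experts each step
        self.step += 1
        return out

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # CPU backend keeps the demo portable
    use_moe = True  # with False, the second backward triggers DDP's
                    # "Expected to have finished reduction in the prior
                    # iteration before starting a new one" error
    model = DistributedDataParallel(ToyMoE(), find_unused_parameters=use_moe)
    for _ in range(2):
        model(torch.randn(4, 8)).sum().backward()
        model.zero_grad()
    dist.destroy_process_group()

Gating the flag on use_moe (as the patch does via getattr) matters because find_unused_parameters=True adds a graph traversal per iteration, so dense models should keep the default False.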