From c65335b56fa8066ec86c88297fb4af050d9ca014 Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Wed, 31 Dec 2025 21:47:04 +0800
Subject: [PATCH] [fix] experts unused

---
 model/model_minimind.py         | 4 +++-
 trainer/train_distill_reason.py | 2 +-
 trainer/train_distillation.py   | 2 +-
 trainer/train_dpo.py            | 2 +-
 trainer/train_full_sft.py       | 2 +-
 trainer/train_grpo.py           | 2 +-
 trainer/train_lora.py           | 2 +-
 trainer/train_ppo.py            | 4 ++--
 trainer/train_pretrain.py       | 2 +-
 trainer/train_spo.py            | 2 +-
 10 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/model/model_minimind.py b/model/model_minimind.py
index f7e49d1..fd40738 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -315,7 +315,9 @@ class MOEFeedForward(nn.Module):
             x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
             y = torch.empty_like(x, dtype=x.dtype)
             for i, expert in enumerate(self.experts):
-                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # ensure dtype consistency
+                expert_out = expert(x[flat_topk_idx == i])
+                if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
+                else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape)
         else:
diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index 6ff7ea3..be69738 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -161,7 +161,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index c99bbd7..712bac6 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -210,7 +210,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config_student, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index d3c3405..b20b53f 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -195,7 +195,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 04eef89..cea2ffb 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -149,7 +149,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 580768a..897d9a8 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -272,7 +272,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index 1fba33d..89cb7a9 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -162,7 +162,7 @@ if __name__ == "__main__":
     # ========== 8. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 9. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py
index 67cb301..cb0ec38 100644
--- a/trainer/train_ppo.py
+++ b/trainer/train_ppo.py
@@ -344,8 +344,8 @@ if __name__ == "__main__":
     if dist.is_initialized():
         actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
-        critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
+        critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
     old_actor_model.to(args.device)
 
     # ========== 8. Start training ==========
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 6cbac88..5f05341 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -148,7 +148,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index 700b14e..37493e4 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -322,7 +322,7 @@ if __name__ == "__main__":
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
+        model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
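
Note on the fix: with DDP's default reducer, every parameter that requires gradients must receive a gradient in each backward pass, otherwise training aborts with a "parameters which did not receive grad" error. In a token-routed MoE layer, an expert that is assigned no tokens in a step drops out of the autograd graph, which is why these trainers previously passed find_unused_parameters=True for MoE configs, an option that adds a full graph traversal to every iteration. The patch instead keeps idle experts in the graph by adding a zero-weighted sum of their parameters to the (empty) output slice, so every trainer can use the cheaper default DDP construction. Below is a minimal single-process sketch of that trick, assuming toy dimensions; ToyExpert and moe_forward are hypothetical stand-ins for illustration, not code from this repository.

# Standalone sketch (hypothetical ToyExpert/moe_forward, toy sizes, no DDP)
# of the dummy-gradient trick applied to MOEFeedForward above.
import torch
import torch.nn as nn

class ToyExpert(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

def moe_forward(experts, x, flat_topk_idx):
    y = torch.empty_like(x)
    for i, expert in enumerate(experts):
        mask = flat_topk_idx == i
        expert_out = expert(x[mask])
        if expert_out.shape[0] > 0:
            y[mask] = expert_out.to(y.dtype)
        else:
            # Zero-weighted parameter sum: numerically a no-op (it adds 0 to
            # an empty slice), but it keeps the idle expert's parameters in
            # the autograd graph so backward() still produces grads for them.
            y[mask] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
    return y

experts = nn.ModuleList(ToyExpert(4) for _ in range(3))
x = torch.randn(6, 4)
flat_topk_idx = torch.tensor([0, 0, 1, 1, 0, 1])  # expert 2 is never selected
moe_forward(experts, x, flat_topk_idx).sum().backward()
for i, expert in enumerate(experts):
    print(f"expert {i} received a grad: {expert.proj.weight.grad is not None}")

Running this prints a gradient-present flag for all three experts even though expert 2 received no tokens; its gradient is all zeros but not absent. Under DDP the same property means the reducer sees a (zero) gradient for every expert parameter each step, so find_unused_parameters can stay at its default of False.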