mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[fix] experts unused
This commit is contained in:
parent
bc8fd82166
commit
c65335b56f
@ -315,7 +315,9 @@ class MOEFeedForward(nn.Module):
|
||||
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
|
||||
y = torch.empty_like(x, dtype=x.dtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
|
||||
expert_out = expert(x[flat_topk_idx == i])
|
||||
if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
|
||||
else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape)
|
||||
else:
|
||||
|
||||
@ -161,7 +161,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -210,7 +210,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config_student, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -195,7 +195,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -149,7 +149,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -272,7 +272,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -162,7 +162,7 @@ if __name__ == "__main__":
|
||||
# ========== 8. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 9. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -344,8 +344,8 @@ if __name__ == "__main__":
|
||||
if dist.is_initialized():
|
||||
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||
old_actor_model.to(args.device)
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
@ -148,7 +148,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
@ -322,7 +322,7 @@ if __name__ == "__main__":
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user