[fix] moe unused parameters: pass find_unused_parameters to DistributedDataParallel when use_moe is enabled

This commit is contained in:
jingyaogong 2025-12-31 21:00:06 +08:00
parent 9236260a4a
commit 5dd4df7e18
9 changed files with 12 additions and 12 deletions

View File

@ -161,7 +161,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -210,7 +210,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config_student, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -195,7 +195,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -96,7 +96,7 @@ if __name__ == "__main__":
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") parser.add_argument('--num_hidden_layers', default=16, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度（中文1token≈1.5~1.7字符）") parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度（中文1token≈1.5~1.7字符）")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构（0=否，1=是）") parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构（0=否，1=是）")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径") parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
@ -149,7 +149,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -272,7 +272,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -162,7 +162,7 @@ if __name__ == "__main__":
# ========== 8. DDP包模型 ========== # ========== 8. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 9. 开始训练 ========== # ========== 9. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -344,8 +344,8 @@ if __name__ == "__main__":
if dist.is_initialized(): if dist.is_initialized():
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank]) actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank]) critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
old_actor_model.to(args.device) old_actor_model.to(args.device)
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========

View File

@ -95,7 +95,7 @@ if __name__ == "__main__":
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") parser.add_argument('--num_hidden_layers', default=16, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度（中文1token≈1.5~1.7字符）") parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度（中文1token≈1.5~1.7字符）")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构（0=否，1=是）") parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构（0=否，1=是）")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径") parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
@ -148,7 +148,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):

View File

@ -322,7 +322,7 @@ if __name__ == "__main__":
# ========== 7. DDP包模型 ========== # ========== 7. DDP包模型 ==========
if dist.is_initialized(): if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
# ========== 8. 开始训练 ========== # ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):