mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[fix] moe unused
This commit is contained in:
parent
9236260a4a
commit
5dd4df7e18
@ -161,7 +161,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -210,7 +210,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config_student, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -195,7 +195,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -96,7 +96,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
parser.add_argument('--num_hidden_layers', default=16, type=int, help="隐藏层数量")
|
||||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
|
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
|
||||||
@ -149,7 +149,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -272,7 +272,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -162,7 +162,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 8. DDP包模型 ==========
|
# ========== 8. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 9. 开始训练 ==========
|
# ========== 9. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -344,8 +344,8 @@ if __name__ == "__main__":
|
|||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
old_actor_model.to(args.device)
|
old_actor_model.to(args.device)
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
|
|||||||
@ -95,7 +95,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
parser.add_argument('--num_hidden_layers', default=16, type=int, help="隐藏层数量")
|
||||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
|
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
|
||||||
@ -148,7 +148,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@ -322,7 +322,7 @@ if __name__ == "__main__":
|
|||||||
# ========== 7. DDP包模型 ==========
|
# ========== 7. DDP包模型 ==========
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=getattr(lm_config, 'use_moe', False))
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user