diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 6b3f426..04eef89 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -96,7 +96,7 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") - parser.add_argument('--num_hidden_layers', default=16, type=int, help="隐藏层数量") + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径") diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 09eb284..6cbac88 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -95,7 +95,7 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") - parser.add_argument('--num_hidden_layers', default=16, type=int, help="隐藏层数量") + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")