[feat] add args

This commit is contained in:
jingyaogong 2025-10-30 10:05:12 +08:00
parent 800fed4639
commit bf123b585d
11 changed files with 57 additions and 57 deletions

View File

@ -2,6 +2,7 @@ import argparse
import random
import warnings
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
@ -14,7 +15,7 @@ def init_model(args):
model = MiniMindForCausalLM(MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
use_moe=args.use_moe,
use_moe=bool(args.use_moe),
inference_rope_scaling=args.inference_rope_scaling
))
moe_suffix = '_moe' if args.use_moe else ''
@ -36,7 +37,7 @@ def main():
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称None表示不使用可选lora_identity, lora_medical")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度512=Small-26M, 640=MoE-145M, 768=Base-104M")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量Small/MoE=8, Base=16")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE架构")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推4倍仅解决位置编码问题")
parser.add_argument('--max_new_tokens', default=8192, type=int, help="最大生成长度(注意:并非模型实际长文本能力)")
parser.add_argument('--temperature', default=0.85, type=float, help="生成温度控制随机性0-1越大越随机")

View File

@ -25,26 +25,24 @@ app = FastAPI()
def init_model(args):
if args.load == 0:
tokenizer = AutoTokenizer.from_pretrained('../model/')
moe_path = '_moe' if args.use_moe else ''
modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason'}
ckp = f'../{args.out_dir}/{modes[args.model_mode]}_{args.hidden_size}{moe_path}.pth'
tokenizer = AutoTokenizer.from_pretrained('../model/')
if args.load_from == 'model':
moe_suffix = '_moe' if args.use_moe else ''
ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
model = MiniMindForCausalLM(MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len,
use_moe=args.use_moe
use_moe=bool(args.use_moe)
))
model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
if args.lora_name != 'None':
if args.lora_weight != 'None':
apply_lora(model)
load_lora(model, f'../{args.out_dir}/{args.lora_name}_{args.hidden_size}.pth')
load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
else:
model_path = '../MiniMind2'
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
return model.eval().to(device), tokenizer
@ -163,15 +161,16 @@ async def chat_completions(request: ChatRequest):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Server for MiniMind")
parser.add_argument('--out_dir', default='out', type=str)
parser.add_argument('--lora_name', default='None', type=str)
parser.add_argument('--hidden_size', default=768, type=int)
parser.add_argument('--num_hidden_layers', default=16, type=int)
parser.add_argument('--max_seq_len', default=8192, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--load', default=0, type=int, help="0: 从原生torch权重1: 利用transformers加载")
parser.add_argument('--model_mode', default=1, type=int,
help="0: 预训练模型1: SFT-Chat模型2: RLHF-Chat模型3: Reason模型")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, tokenizer = init_model(parser.parse_args())
parser.add_argument('--load_from', default='model', type=str, help="模型加载路径model=原生torch权重其他路径=transformers格式")
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo")
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称None表示不使用可选lora_identity, lora_medical")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度512=Small-26M, 640=MoE-145M, 768=Base-104M")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量Small/MoE=8, Base=16")
parser.add_argument('--max_seq_len', default=8192, type=int, help="最大序列长度")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
args = parser.parse_args()
device = args.device
model, tokenizer = init_model(args)
uvicorn.run(app, host="0.0.0.0", port=8998)

View File

@ -107,10 +107,10 @@ if __name__ == "__main__":
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="推理蒸馏数据路径")
parser.add_argument('--from_weight', default='dpo', type=str, help="基于哪个权重训练默认dpo")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb项目名")
args = parser.parse_args()
@ -122,7 +122,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -148,12 +148,12 @@ if __name__ == "__main__":
parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重总损失=alpha*CE+(1-alpha)*KL")
parser.add_argument('--temperature', default=2.0, type=float, help="蒸馏温度")
parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度推荐范围1.0-2.0")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
args = parser.parse_args()
@ -165,8 +165,8 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=args.use_moe)
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=args.use_moe)
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.use_moe))
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -132,10 +132,10 @@ if __name__ == "__main__":
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
@ -148,7 +148,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -95,10 +95,10 @@ if __name__ == "__main__":
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练为none则不基于任何权重训练")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
args = parser.parse_args()
@ -110,7 +110,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -201,15 +201,15 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型0=普通模型1=推理模型)')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
args = parser.parse_args()
@ -222,7 +222,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -90,10 +90,10 @@ if __name__ == "__main__":
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练默认full_sft")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
args = parser.parse_args()
@ -105,7 +105,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -250,17 +250,17 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数")
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型0=普通模型1=推理模型)')
parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
args = parser.parse_args()
@ -272,7 +272,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -94,10 +94,10 @@ if __name__ == "__main__":
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练为none则从头开始")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
args = parser.parse_args()
@ -109,7 +109,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========

View File

@ -249,14 +249,14 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型0=普通模型1=推理模型)')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb项目名")
args = parser.parse_args()
@ -269,7 +269,7 @@ if __name__ == "__main__":
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========