[feat] update args

This commit is contained in:
jingyaogong 2025-10-30 10:48:31 +08:00
parent bf123b585d
commit 08ce3da228
2 changed files with 7 additions and 6 deletions

View File

@ -11,7 +11,7 @@ warnings.filterwarnings('ignore')
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
if args.load_from == 'model':
if 'model' in args.load_from:
model = MiniMindForCausalLM(MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,

View File

@ -25,15 +25,16 @@ app = FastAPI()
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained('../model/')
if args.load_from == 'model':
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
if 'model' in args.load_from:
moe_suffix = '_moe' if args.use_moe else ''
ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
model = MiniMindForCausalLM(MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len,
use_moe=bool(args.use_moe)
use_moe=bool(args.use_moe),
inference_rope_scaling=args.inference_rope_scaling
))
model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
if args.lora_weight != 'None':
@ -41,7 +42,6 @@ def init_model(args):
load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
else:
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
return model.eval().to(device), tokenizer
@ -161,7 +161,7 @@ async def chat_completions(request: ChatRequest):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Server for MiniMind")
parser.add_argument('--load_from', default='model', type=str, help="模型加载路径model=原生torch权重其他路径=transformers格式")
parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径model=原生torch权重其他路径=transformers格式")
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo")
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称None表示不使用可选lora_identity, lora_medical")
@ -169,6 +169,7 @@ if __name__ == "__main__":
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量Small/MoE=8, Base=16")
parser.add_argument('--max_seq_len', default=8192, type=int, help="最大序列长度")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推4倍仅解决位置编码问题")
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
args = parser.parse_args()
device = args.device