mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[feat] update args
This commit is contained in:
parent
bf123b585d
commit
08ce3da228
@ -11,7 +11,7 @@ warnings.filterwarnings('ignore')
|
||||
|
||||
def init_model(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||
if args.load_from == 'model':
|
||||
if 'model' in args.load_from:
|
||||
model = MiniMindForCausalLM(MiniMindConfig(
|
||||
hidden_size=args.hidden_size,
|
||||
num_hidden_layers=args.num_hidden_layers,
|
||||
|
||||
@ -25,15 +25,16 @@ app = FastAPI()
|
||||
|
||||
|
||||
def init_model(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
if args.load_from == 'model':
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||
if 'model' in args.load_from:
|
||||
moe_suffix = '_moe' if args.use_moe else ''
|
||||
ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
|
||||
model = MiniMindForCausalLM(MiniMindConfig(
|
||||
hidden_size=args.hidden_size,
|
||||
num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len,
|
||||
use_moe=bool(args.use_moe)
|
||||
use_moe=bool(args.use_moe),
|
||||
inference_rope_scaling=args.inference_rope_scaling
|
||||
))
|
||||
model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
|
||||
if args.lora_weight != 'None':
|
||||
@ -41,7 +42,6 @@ def init_model(args):
|
||||
load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
|
||||
return model.eval().to(device), tokenizer
|
||||
|
||||
@ -161,7 +161,7 @@ async def chat_completions(request: ChatRequest):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Server for MiniMind")
|
||||
parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
|
||||
parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
|
||||
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
|
||||
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo)")
|
||||
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
|
||||
@ -169,6 +169,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量(Small/MoE=8, Base=16)")
|
||||
parser.add_argument('--max_seq_len', default=8192, type=int, help="最大序列长度")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
|
||||
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
|
||||
args = parser.parse_args()
|
||||
device = args.device
|
||||
|
||||
Loading…
Reference in New Issue
Block a user