mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[update] minimind-3
This commit is contained in:
+14
-11
@@ -23,7 +23,7 @@ def init_model(args):
|
||||
model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
|
||||
if args.lora_weight != 'None':
|
||||
apply_lora(model)
|
||||
load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
|
||||
load_lora(model, f'./{args.save_dir}/{args.lora_weight}_{args.hidden_size}.pth')
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||
get_model_params(model, model.config)
|
||||
@@ -35,13 +35,14 @@ def main():
|
||||
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
|
||||
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
|
||||
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量(Small/MoE=8, Base=16)")
|
||||
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
|
||||
parser.add_argument('--max_new_tokens', default=8192, type=int, help="最大生成长度(注意:并非模型实际长文本能力)")
|
||||
parser.add_argument('--temperature', default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
|
||||
parser.add_argument('--top_p', default=0.85, type=float, help="nucleus采样阈值(0-1)")
|
||||
parser.add_argument('--top_p', default=0.95, type=float, help="nucleus采样阈值(0-1)")
|
||||
parser.add_argument('--open_thinking', default=0, type=int, help="是否开启自适应思考(0=否,1=是)")
|
||||
parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数(需为偶数,0表示不携带历史)")
|
||||
parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度(tokens/s)")
|
||||
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
|
||||
@@ -65,23 +66,25 @@ def main():
|
||||
|
||||
prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
|
||||
for prompt in prompt_iter:
|
||||
setup_seed(2026) # or setup_seed(random.randint(0, 2048))
|
||||
setup_seed(random.randint(0, 31415926))
|
||||
setup_seed(42)
|
||||
if input_mode == 0: print(f'💬: {prompt}')
|
||||
conversation = conversation[-args.historys:] if args.historys else []
|
||||
conversation.append({"role": "user", "content": prompt})
|
||||
|
||||
templates = {"conversation": conversation, "tokenize": False, "add_generation_prompt": True}
|
||||
if args.weight == 'reason': templates["enable_thinking"] = True # 仅Reason模型使用
|
||||
inputs = tokenizer.apply_chat_template(**templates) if args.weight != 'pretrain' else (tokenizer.bos_token + prompt)
|
||||
if 'pretrain' in args.weight:
|
||||
inputs = tokenizer.bos_token + prompt
|
||||
else:
|
||||
inputs = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, open_thinking=bool(args.open_thinking))
|
||||
|
||||
inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)
|
||||
|
||||
print('🤖: ', end='')
|
||||
print('🧠: ', end='')
|
||||
st = time.time()
|
||||
generated_ids = model.generate(
|
||||
inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
|
||||
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
||||
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
|
||||
top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0
|
||||
top_p=args.top_p, temperature=args.temperature, repetition_penalty=1
|
||||
)
|
||||
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
|
||||
Reference in New Issue
Block a user