perf: merge LoRA weights into base model for inference

whitesword 2025-12-23 20:55:52 +08:00
parent 11b962da06
commit 41ef5fd8b8
2 changed files with 15 additions and 2 deletions
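Why merging is safe: a LoRA-adapted linear layer computes y = W x + B (A x) = (W + B A) x, so folding the low-rank product B @ A into W once at load time is numerically equivalent (up to floating-point rounding) while removing the extra per-token matmuls. A minimal standalone sketch of that identity (illustrative shapes, not repository code):

import torch

out_f, in_f, rank = 8, 16, 4
W = torch.randn(out_f, in_f)   # base weight
A = torch.randn(rank, in_f)    # LoRA down-projection
B = torch.randn(out_f, rank)   # LoRA up-projection
x = torch.randn(in_f)

y_lora = W @ x + B @ (A @ x)   # runtime LoRA path: two extra matmuls
y_merged = (W + B @ A) @ x     # merged path: a plain dense layer

assert torch.allclose(y_lora, y_merged, atol=1e-5)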


@@ -22,8 +22,7 @@ def init_model(args):
         ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
         model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
         if args.lora_weight != 'None':
-            apply_lora(model)
-            load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
+            merge_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
     else:
         model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
     print(f'MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
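For context, the replaced path kept LoRA modules attached during inference (apply_lora followed by load_lora), which adds two small matmuls to every wrapped linear on each forward pass; merge_lora pays that cost once at load. A rough timing sketch of the difference (dimensions are illustrative, not taken from MiniMind):

import time
import torch

hidden, rank, tokens = 512, 8, 1024
W = torch.randn(hidden, hidden)
A = torch.randn(rank, hidden)
B = torch.randn(hidden, rank)
x = torch.randn(tokens, hidden)

W_merged = W + B @ A  # one-time fold, done at load time

t0 = time.perf_counter()
for _ in range(100):
    y = x @ W.T + (x @ A.T) @ B.T  # unmerged: low-rank path on every call
t1 = time.perf_counter()
for _ in range(100):
    y = x @ W_merged.T             # merged: single dense matmul
t2 = time.perf_counter()
print(f'unmerged: {t1 - t0:.4f}s  merged: {t2 - t1:.4f}s')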


@@ -41,6 +41,20 @@ def load_lora(model, path):
             lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
             module.lora.load_state_dict(lora_state)
 
 
+def merge_lora(model, path):
+    state_dict = torch.load(path, map_location=model.device)
+    # Strip any 'module.' prefix so the keys match the model's module names
+    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            key_A = f"{name}.lora.A.weight"
+            key_B = f"{name}.lora.B.weight"
+            if key_A in state_dict and key_B in state_dict:
+                # Merge the weights directly: W_new = W_old + B @ A
+                module.weight.data += state_dict[key_B] @ state_dict[key_A]
+    return model
+
+
 def save_lora(model, path):
     state_dict = {}
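One caveat: merge_lora mutates module.weight in place, so the in-memory model is no longer the pristine base model afterwards; to switch adapters you would reload the base checkpoint or subtract the same delta. A hedged sketch of the reverse operation (unmerge_lora is hypothetical, not part of this commit):

import torch
from torch import nn

def unmerge_lora(model, path):
    # Hypothetical inverse of merge_lora: subtract the low-rank delta that
    # was added, restoring the base weights up to floating-point rounding.
    state_dict = torch.load(path, map_location='cpu')
    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                delta = state_dict[key_B] @ state_dict[key_A]
                module.weight.data -= delta.to(module.weight.device)
    return model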