From 41ef5fd8b88d0410a8c831ea8b8ebb619d3f6554 Mon Sep 17 00:00:00 2001 From: whitesword Date: Tue, 23 Dec 2025 20:55:52 +0800 Subject: [PATCH 1/2] perf: merge LoRA weights into base model for inference --- eval_llm.py | 3 +-- model/model_lora.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/eval_llm.py b/eval_llm.py index 8af17d4..6955697 100755 --- a/eval_llm.py +++ b/eval_llm.py @@ -22,8 +22,7 @@ def init_model(args): ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth' model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True) if args.lora_weight != 'None': - apply_lora(model) - load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth') + merge_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth') else: model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True) print(f'MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)') diff --git a/model/model_lora.py b/model/model_lora.py index c675f3b..9fe8419 100644 --- a/model/model_lora.py +++ b/model/model_lora.py @@ -41,6 +41,20 @@ def load_lora(model, path): lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k} module.lora.load_state_dict(lora_state) +def merge_lora(model, path): + state_dict = torch.load(path, map_location=model.device) + # 移除可能的module前缀,确保key与模型层级名称一致 + state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()} + + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + key_A = f"{name}.lora.A.weight" + key_B = f"{name}.lora.B.weight" + if key_A in state_dict and key_B in state_dict: + # 直接合并权重: W_new = W_old + B @ A + module.weight.data += state_dict[key_B] @ state_dict[key_A] + + return model def save_lora(model, path): state_dict = {} From daf6cc0c2e55f511be0ea0cc6c8bcd74dbc103bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AD=90=E6=B5=A9?= <95196117+whiteswordLI@users.noreply.github.com> Date: Wed, 24 Dec 2025 16:14:18 +0800 Subject: [PATCH 2/2] Apply suggestions from code review fix some bugs in merge_lora Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- model/model_lora.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model/model_lora.py b/model/model_lora.py index 9fe8419..fa80012 100644 --- a/model/model_lora.py +++ b/model/model_lora.py @@ -47,12 +47,15 @@ def merge_lora(model, path): state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()} for name, module in model.named_modules(): - if isinstance(module, nn.Linear): + if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]: key_A = f"{name}.lora.A.weight" key_B = f"{name}.lora.B.weight" if key_A in state_dict and key_B in state_dict: # 直接合并权重: W_new = W_old + B @ A - module.weight.data += state_dict[key_B] @ state_dict[key_A] + device = module.weight.device + A_weight = state_dict[key_A].to(device) + B_weight = state_dict[key_B].to(device) + module.weight.data += B_weight @ A_weight return model