mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-14 04:07:17 +08:00
Apply suggestions from code review
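Fix some bugs in merge_lora: merge only into square nn.Linear layers, and move the LoRA A/B weights onto the module's device before adding B @ A.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>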
This commit is contained in:
parent
41ef5fd8b8
commit
daf6cc0c2e
@ -47,12 +47,15 @@ def merge_lora(model, path):
|
|||||||
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
|
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
|
||||||
key_A = f"{name}.lora.A.weight"
|
key_A = f"{name}.lora.A.weight"
|
||||||
key_B = f"{name}.lora.B.weight"
|
key_B = f"{name}.lora.B.weight"
|
||||||
if key_A in state_dict and key_B in state_dict:
|
if key_A in state_dict and key_B in state_dict:
|
||||||
# 直接合并权重: W_new = W_old + B @ A
|
# 直接合并权重: W_new = W_old + B @ A
|
||||||
module.weight.data += state_dict[key_B] @ state_dict[key_A]
|
device = module.weight.device
|
||||||
|
A_weight = state_dict[key_A].to(device)
|
||||||
|
B_weight = state_dict[key_B].to(device)
|
||||||
|
module.weight.data += B_weight @ A_weight
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user