mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
Merge daf6cc0c2e into 83e52f6a27
This commit is contained in:
commit
03a71c9463
@ -22,8 +22,7 @@ def init_model(args):
|
||||
ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
|
||||
model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
|
||||
if args.lora_weight != 'None':
|
||||
apply_lora(model)
|
||||
load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
|
||||
merge_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||
get_model_params(model, model.config)
|
||||
|
||||
@ -41,6 +41,23 @@ def load_lora(model, path):
|
||||
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
|
||||
module.lora.load_state_dict(lora_state)
|
||||
|
||||
def merge_lora(model, path):
    """Fold saved LoRA low-rank factors back into the base model weights.

    Loads the LoRA state dict from *path*, strips any ``module.`` prefix
    (left over from DataParallel-style saving) so keys line up with the
    model's own module names, and for every square ``nn.Linear`` layer
    that has matching ``A``/``B`` factors adds ``B @ A`` onto the layer
    weight in place.

    Returns the same (mutated) model for call chaining.
    """
    raw = torch.load(path, map_location=model.device)
    # Normalize keys: drop a possible 'module.' prefix so they match the
    # model's module hierarchy names.
    weights = {}
    for key, tensor in raw.items():
        if key.startswith('module.'):
            key = key[len('module.'):]
        weights[key] = tensor

    for name, layer in model.named_modules():
        if not isinstance(layer, nn.Linear):
            continue
        rows, cols = layer.weight.shape
        if rows != cols:
            # LoRA was only attached to square projection layers here.
            continue
        a_key = f"{name}.lora.A.weight"
        b_key = f"{name}.lora.B.weight"
        if a_key not in weights or b_key not in weights:
            continue
        # Merge directly into the base weight: W_new = W_old + B @ A,
        # with both factors moved onto the layer's own device first.
        target = layer.weight.device
        layer.weight.data += weights[b_key].to(target) @ weights[a_key].to(target)

    return model
|
||||
|
||||
def save_lora(model, path):
|
||||
raw_model = getattr(model, '_orig_mod', model)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user