[fix] lora weight

This commit is contained in:
jingyaogong 2025-12-22 21:27:29 +08:00
parent 048d84abc7
commit a9c56b20e9

View File

@ -34,15 +34,7 @@ def apply_lora(model, rank=8):
def load_lora(model, path): def load_lora(model, path):
state_dict = torch.load(path, map_location=model.device) state_dict = torch.load(path, map_location=model.device)
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
# 兼容DDP训练保存的权重带有module.前缀),去除前缀以匹配当前模型
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
state_dict = new_state_dict
for name, module in model.named_modules(): for name, module in model.named_modules():
if hasattr(module, 'lora'): if hasattr(module, 'lora'):
@ -54,6 +46,7 @@ def save_lora(model, path):
state_dict = {} state_dict = {}
for name, module in model.named_modules(): for name, module in model.named_modules():
if hasattr(module, 'lora'): if hasattr(module, 'lora'):
lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()} clean_name = name[7:] if name.startswith("module.") else name
lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
state_dict.update(lora_state) state_dict.update(lora_state)
torch.save(state_dict, path) torch.save(state_dict, path)