mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 19:58:15 +08:00
Merge pull request #594 from whiteswordLI/fix/lora-load-ddp-weights
Fix: support loading DDP-saved LoRA weights for inference
This commit is contained in:
commit
048d84abc7
@ -34,6 +34,16 @@ def apply_lora(model, rank=8):
|
|||||||
|
|
||||||
def load_lora(model, path):
|
def load_lora(model, path):
|
||||||
state_dict = torch.load(path, map_location=model.device)
|
state_dict = torch.load(path, map_location=model.device)
|
||||||
|
|
||||||
|
# 兼容DDP训练保存的权重(带有module.前缀),去除前缀以匹配当前模型
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if k.startswith('module.'):
|
||||||
|
new_state_dict[k[7:]] = v
|
||||||
|
else:
|
||||||
|
new_state_dict[k] = v
|
||||||
|
state_dict = new_state_dict
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if hasattr(module, 'lora'):
|
if hasattr(module, 'lora'):
|
||||||
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
|
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user