Fix: support loading DDP-saved LoRA weights for inference

whitesword 2025-12-22 20:50:25 +08:00
parent fe24501602
commit 3a18fdd666


@@ -34,6 +34,16 @@ def apply_lora(model, rank=8):
def load_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    # Compatibility with weights saved during DDP training (keys carry a
    # 'module.' prefix): strip the prefix so the keys match the current model.
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v  # drop the 7-character 'module.' prefix
        else:
            new_state_dict[k] = v
    state_dict = new_state_dict
    for name, module in model.named_modules():
        if hasattr(module, 'lora'):
            lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
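
For reference, PyTorch ships a helper that performs the same prefix stripping, so the manual loop above could be replaced with a single call. Below is a minimal, self-contained sketch of that alternative; the key names and tensor shapes are illustrative, not taken from this repo:

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Simulate a checkpoint saved from a DDP-wrapped model: DistributedDataParallel
# namespaces every parameter key under 'module.'.
state_dict = {
    'module.layers.0.lora.A': torch.zeros(8, 512),
    'module.layers.0.lora.B': torch.zeros(512, 8),
}
# Strips the 'module.' prefix in place from every key that carries it.
consume_prefix_in_state_dict_if_present(state_dict, 'module.')
print(list(state_dict))  # ['layers.0.lora.A', 'layers.0.lora.B']

Either way, once the prefix is removed the same load_lora(model, path) call works for checkpoints saved with or without DDP.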