From a9c56b20e94ffafae84092e0edff4aaca2f32d85 Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Mon, 22 Dec 2025 21:27:29 +0800
Subject: [PATCH] [fix] lora weight

---
 model/model_lora.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/model/model_lora.py b/model/model_lora.py
index 8011526..c675f3b 100644
--- a/model/model_lora.py
+++ b/model/model_lora.py
@@ -34,15 +34,7 @@ def apply_lora(model, rank=8):
 
 def load_lora(model, path):
     state_dict = torch.load(path, map_location=model.device)
-
-    # Compatibility with weights saved during DDP training (prefixed with 'module.'): strip the prefix to match the current model
-    new_state_dict = {}
-    for k, v in state_dict.items():
-        if k.startswith('module.'):
-            new_state_dict[k[7:]] = v
-        else:
-            new_state_dict[k] = v
-    state_dict = new_state_dict
+    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
 
     for name, module in model.named_modules():
         if hasattr(module, 'lora'):
@@ -54,6 +46,7 @@ def save_lora(model, path):
     state_dict = {}
     for name, module in model.named_modules():
         if hasattr(module, 'lora'):
-            lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
+            clean_name = name[7:] if name.startswith("module.") else name
+            lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
             state_dict.update(lora_state)
     torch.save(state_dict, path)
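
Context for the fix, as a self-contained sketch: DistributedDataParallel keeps
the wrapped model under an attribute literally named 'module', so every name
yielded by named_modules() and every key in a saved state_dict gains a
'module.' prefix, and keys saved from the wrapped model then fail to match an
unwrapped model at load time. The snippet below simulates that layout with a
plain wrapper so it runs without a process group; the LoRA and Wrapper classes
and the toy Sequential model are hypothetical stand-ins for illustration, not
the repo's actual definitions.

    import torch
    import torch.nn as nn

    class LoRA(nn.Module):
        # Hypothetical stand-in for the repo's LoRA adapter: a low-rank A/B pair.
        def __init__(self, in_features, out_features, rank=8):
            super().__init__()
            self.A = nn.Linear(in_features, rank, bias=False)
            self.B = nn.Linear(rank, out_features, bias=False)

    class Wrapper(nn.Module):
        # Mimics DDP's layout: the real model lives under an attribute
        # literally named `module`, which prefixes all submodule names.
        def __init__(self, model):
            super().__init__()
            self.module = model

    model = nn.Sequential(nn.Linear(4, 4))
    model[0].lora = LoRA(4, 4)   # attach an adapter the way apply_lora would
    wrapped = Wrapper(model)

    # Save path (mirrors the patched save_lora): strip the prefix from the
    # module name before building keys, so the file loads into an unwrapped model.
    state_dict = {}
    for name, module in wrapped.named_modules():
        if hasattr(module, 'lora'):
            clean_name = name[7:] if name.startswith('module.') else name
            state_dict.update({f'{clean_name}.lora.{k}': v
                               for k, v in module.lora.state_dict().items()})
    print(sorted(state_dict))    # ['0.lora.A.weight', '0.lora.B.weight']

    # Load path (mirrors the patched load_lora): normalize the keys of a
    # checkpoint saved before the fix, i.e. with the prefix still present.
    old = {'module.0.lora.A.weight': torch.zeros(8, 4)}
    fixed = {(k[7:] if k.startswith('module.') else k): v for k, v in old.items()}
    print(sorted(fixed))         # ['0.lora.A.weight']

Note the two halves of the patch are complementary: the save-path change fixes
new checkpoints at the source, while the load-path key normalization keeps
checkpoints written before this patch loadable.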