From 4a68da72d5a6c0c8817805b2b627bb935280b12a Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Sun, 31 May 2026 13:38:51 +0800 Subject: [PATCH] [fix] lora dims --- model/model_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/model_lora.py b/model/model_lora.py index 45b4961..9d96955 100644 --- a/model/model_lora.py +++ b/model/model_lora.py @@ -20,8 +20,8 @@ class LoRA(nn.Module): def apply_lora(model, rank=16): for name, module in model.named_modules(): - if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]: - lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device) + if isinstance(module, nn.Linear) and module.in_features == module.out_features: + lora = LoRA(module.in_features, module.out_features, rank=rank).to(model.device) setattr(module, "lora", lora) original_forward = module.forward