diff --git a/model/model_lora.py b/model/model_lora.py index 45b4961..9d96955 100644 --- a/model/model_lora.py +++ b/model/model_lora.py @@ -20,8 +20,8 @@ class LoRA(nn.Module): def apply_lora(model, rank=16): for name, module in model.named_modules(): - if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]: - lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device) + if isinstance(module, nn.Linear) and module.in_features == module.out_features: + lora = LoRA(module.in_features, module.out_features, rank=rank).to(model.device) setattr(module, "lora", lora) original_forward = module.forward