[fix] lora dims

This commit is contained in:
jingyaogong
2026-05-31 13:38:51 +08:00
parent 4497610ec0
commit 4a68da72d5
+2 -2
View File
@@ -20,8 +20,8 @@ class LoRA(nn.Module):
def apply_lora(model, rank=16):
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
if isinstance(module, nn.Linear) and module.in_features == module.out_features:
lora = LoRA(module.in_features, module.out_features, rank=rank).to(model.device)
setattr(module, "lora", lora)
original_forward = module.forward