mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] lora dims
This commit is contained in:
+2
-2
@@ -20,8 +20,8 @@ class LoRA(nn.Module):
|
||||
|
||||
def apply_lora(model, rank=16):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
|
||||
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
|
||||
if isinstance(module, nn.Linear) and module.in_features == module.out_features:
|
||||
lora = LoRA(module.in_features, module.out_features, rank=rank).to(model.device)
|
||||
setattr(module, "lora", lora)
|
||||
original_forward = module.forward
|
||||
|
||||
|
||||
Reference in New Issue
Block a user