update lora

This commit is contained in:
jingyaogong 2025-04-27 15:45:06 +08:00
parent 2e118e9d3d
commit 5ffde04b7c
2 changed files with 17 additions and 11 deletions

View File

@ -18,7 +18,7 @@ class LoRA(nn.Module):
return self.B(self.A(x))
def apply_lora(model, rank=16):
def apply_lora(model, rank=8):
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)

View File

@ -1,5 +1,6 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
@ -7,13 +8,15 @@ import argparse
import time
import math
import warnings
import torch
from torch import optim, nn
import torch.distributed as dist
from contextlib import nullcontext
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import SFTDataset
from model.model_lora import *
from model.model_lora import load_lora, save_lora, apply_lora
warnings.filterwarnings('ignore')
@ -80,8 +83,10 @@ def train_epoch(epoch, wandb):
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth'
os.makedirs(os.path.dirname(lora_save_path), exist_ok=True)
# 【区别1】只保存lora权重即可
save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth')
save_lora(model, lora_save_path)
model.train()
@ -89,7 +94,7 @@ def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('../model/')
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
return model.to(args.device), tokenizer
@ -110,9 +115,9 @@ def init_distributed_mode():
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
@ -123,17 +128,18 @@ if __name__ == "__main__":
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=1)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl")
parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
parser.add_argument("--data_path", type=str, default="../dataset/lora_medical.jsonl")
parser.add_argument("--lora_name", type=str, default="lora_medical", help="根据任务保存成lora_(英文/医学/心理...)")
args = parser.parse_args()
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)