update lora

Author: jingyaogong
Date:   2025-04-27 15:45:06 +08:00
Parent: 2e118e9d3d
Commit: 5ffde04b7c

2 changed files with 17 additions and 11 deletions
model/model_lora.py (+1 -1)
@@ -18,7 +18,7 @@ class LoRA(nn.Module):
         return self.B(self.A(x))
 
 
-def apply_lora(model, rank=16):
+def apply_lora(model, rank=8):
     for name, module in model.named_modules():
         if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
             lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
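For context, a minimal sketch of the LoRA module this hunk touches, assuming the usual low-rank adapter parameterization; only forward() and apply_lora appear in the diff, so the names A and B and the init scheme below follow common convention rather than the confirmed contents of model/model_lora.py:

# Hedged sketch of the assumed LoRA class; the real __init__ may differ.
import torch
from torch import nn

class LoRA(nn.Module):
    def __init__(self, in_features, out_features, rank):
        super().__init__()
        # Two thin projections, d -> rank -> d; trainable parameters drop
        # from d*d to 2*d*rank.
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)
        # Conventional LoRA init: B starts at zero so the adapter is a
        # no-op until training updates it.
        nn.init.normal_(self.A.weight, std=0.02)
        nn.init.zeros_(self.B.weight)

    def forward(self, x):
        return self.B(self.A(x))

Note that apply_lora only targets square nn.Linear layers (weight.shape[0] == weight.shape[1]), i.e. projections whose input and output widths match; halving the default rank from 16 to 8 halves each adapter's parameter count.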
trainer/train_lora.py (+16 -10)
@@ -1,5 +1,6 @@
 import os
 import sys
+
 __package__ = "trainer"
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
@@ -7,13 +8,15 @@ import argparse
 import time
 import math
 import warnings
+import torch
+from torch import optim, nn
 import torch.distributed as dist
 from contextlib import nullcontext
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
 from dataset.lm_dataset import SFTDataset
-from model.model_lora import *
+from model.model_lora import load_lora, save_lora, apply_lora
 
 warnings.filterwarnings('ignore')
@@ -80,8 +83,10 @@ def train_epoch(epoch, wandb):
         if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
+            lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth'
+            os.makedirs(os.path.dirname(lora_save_path), exist_ok=True)
             # [Difference 1] save only the LoRA weights
-            save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth')
+            save_lora(model, lora_save_path)
             model.train()
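The comment marks the key difference from full SFT: only adapter weights go to disk. A hedged sketch of what save_lora plausibly does; the real function is imported from model.model_lora and is not shown in this diff, and the key filtering by the substring 'lora' is an assumption:

import torch

def save_lora(model, path):
    # Keep only the parameters belonging to the attached LoRA adapters;
    # the frozen base weights are restored from their own checkpoint.
    lora_state = {k: v for k, v in model.state_dict().items() if 'lora' in k}
    torch.save(lora_state, path)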
@@ -89,7 +94,7 @@ def init_model(lm_config):
     tokenizer = AutoTokenizer.from_pretrained('../model/')
     model = MiniMindForCausalLM(lm_config)
     moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
+    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
     state_dict = torch.load(ckp, map_location=args.device)
     model.load_state_dict(state_dict, strict=False)
     return model.to(args.device), tokenizer
@@ -110,9 +115,9 @@ def init_distributed_mode():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
     parser.add_argument("--out_dir", type=str, default="../out")
-    parser.add_argument("--epochs", type=int, default=50)
-    parser.add_argument("--batch_size", type=int, default=16)
-    parser.add_argument("--learning_rate", type=float, default=5e-5)
+    parser.add_argument("--epochs", type=int, default=10)
+    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--learning_rate", type=float, default=1e-4)
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--use_wandb", action="store_true")
@@ -123,17 +128,18 @@ if __name__ == "__main__":
     parser.add_argument("--grad_clip", type=float, default=1.0)
     parser.add_argument("--warmup_iters", type=int, default=0)
     parser.add_argument("--log_interval", type=int, default=100)
-    parser.add_argument("--save_interval", type=int, default=1)
+    parser.add_argument("--save_interval", type=int, default=100)
     parser.add_argument('--local_rank', type=int, default=-1)
     parser.add_argument('--hidden_size', default=512, type=int)
     parser.add_argument('--num_hidden_layers', default=8, type=int)
     parser.add_argument('--max_seq_len', default=512, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
-    parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl")
-    parser.add_argument("--lora_name", type=str, default="lora_identity", help="save as lora_(english/medical/psychology/...) according to the task")
+    parser.add_argument("--data_path", type=str, default="../dataset/lora_medical.jsonl")
+    parser.add_argument("--lora_name", type=str, default="lora_medical", help="save as lora_(english/medical/psychology/...) according to the task")
     args = parser.parse_args()
 
-    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
+                               use_moe=args.use_moe)
     args.save_dir = os.path.join(args.out_dir)
     os.makedirs(args.save_dir, exist_ok=True)
     os.makedirs(args.out_dir, exist_ok=True)
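Putting the pieces together, a hedged sketch of the fine-tuning setup this script implies: load the full_sft base checkpoint, attach adapters, and optimize only the adapter parameters. The exact statements and variable names in train_lora.py may differ; only init_model, apply_lora, and the new defaults are confirmed by the diff above.

model, tokenizer = init_model(lm_config)   # loads full_sft_*.pth as the base
apply_lora(model, rank=8)                  # adapters on square nn.Linear layers
for name, param in model.named_parameters():
    param.requires_grad = 'lora' in name   # freeze everything but the adapters
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=args.learning_rate,                 # 1e-4 per the new default
)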