diff --git a/model/model_lora.py b/model/model_lora.py index ea53a27..b7c1d4c 100644 --- a/model/model_lora.py +++ b/model/model_lora.py @@ -18,7 +18,7 @@ class LoRA(nn.Module): return self.B(self.A(x)) -def apply_lora(model, rank=16): +def apply_lora(model, rank=8): for name, module in model.named_modules(): if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]: lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device) diff --git a/trainer/train_lora.py b/trainer/train_lora.py index 557c056..4711062 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -1,5 +1,6 @@ import os import sys + __package__ = "trainer" sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) @@ -7,13 +8,15 @@ import argparse import time import math import warnings +import torch +from torch import optim, nn import torch.distributed as dist from contextlib import nullcontext from torch.utils.data import DataLoader, DistributedSampler from transformers import AutoTokenizer, AutoModelForCausalLM from model.model_minimind import MiniMindConfig, MiniMindForCausalLM from dataset.lm_dataset import SFTDataset -from model.model_lora import * +from model.model_lora import load_lora, save_lora, apply_lora warnings.filterwarnings('ignore') @@ -80,8 +83,10 @@ def train_epoch(epoch, wandb): if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): model.eval() + lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth' + os.makedirs(os.path.dirname(lora_save_path), exist_ok=True) # 【区别1】只保存lora权重即可 - save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth') + save_lora(model, lora_save_path) model.train() @@ -89,7 +94,7 @@ def init_model(lm_config): tokenizer = AutoTokenizer.from_pretrained('../model/') model = MiniMindForCausalLM(lm_config) moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth' + ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth' state_dict = torch.load(ckp, map_location=args.device) model.load_state_dict(state_dict, strict=False) return model.to(args.device), tokenizer @@ -110,9 +115,9 @@ def init_distributed_mode(): if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA") parser.add_argument("--out_dir", type=str, default="../out") - parser.add_argument("--epochs", type=int, default=50) - parser.add_argument("--batch_size", type=int, default=16) - parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--use_wandb", action="store_true") @@ -123,17 +128,18 @@ if __name__ == "__main__": parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--warmup_iters", type=int, default=0) parser.add_argument("--log_interval", type=int, default=100) - parser.add_argument("--save_interval", type=int, default=1) + parser.add_argument("--save_interval", type=int, default=100) parser.add_argument('--local_rank', type=int, default=-1) parser.add_argument('--hidden_size', default=512, type=int) parser.add_argument('--num_hidden_layers', default=8, type=int) parser.add_argument('--max_seq_len', default=512, type=int) parser.add_argument('--use_moe', default=False, type=bool) - parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl") - parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)") + parser.add_argument("--data_path", type=str, default="../dataset/lora_medical.jsonl") + parser.add_argument("--lora_name", type=str, default="lora_medical", help="根据任务保存成lora_(英文/医学/心理...)") args = parser.parse_args() - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe) + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, + use_moe=args.use_moe) args.save_dir = os.path.join(args.out_dir) os.makedirs(args.save_dir, exist_ok=True) os.makedirs(args.out_dir, exist_ok=True)