From e8484874f5005fc133ce41cac7abbf5c787c2a19 Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Sun, 26 Oct 2025 18:49:52 +0800
Subject: [PATCH] [feat] pause-training

---
 trainer/train_distill_reason.py | 230 ++++++++++--------
 trainer/train_distillation.py   | 266 ++++++++++++---------------
 trainer/train_dpo.py            | 281 ++++++++++++-----------------
 trainer/train_full_sft.py       | 220 +++++++++--------------
 trainer/train_grpo.py           | 285 ++++++++++++++---------------
 trainer/train_lora.py           | 252 +++++++++++---------------
 trainer/train_ppo.py            | 309 ++++++++++++++----------------
 trainer/train_pretrain.py       | 221 ++++++++++-------------
 trainer/train_spo.py            | 296 ++++++++++++++----------------
 trainer/trainer_utils.py        | 139 ++++++++++++++
 10 files changed, 1171 insertions(+), 1328 deletions(-)
 create mode 100644 trainer/trainer_utils.py
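Note: the new trainer/trainer_utils.py (139 lines) is created by this patch, but its hunk is not included in this excerpt. Below is a minimal sketch of the shared helpers that the rewritten trainers call, reconstructed purely from their call sites in the diffs that follow; all names, file paths, and behavior here are assumptions and the real file may differ. SkipBatchSampler is sketched separately after the first file's diff.

    import math
    import os
    import random

    import numpy as np
    import torch
    import torch.distributed as dist


    def is_main_process():
        # rank 0, or single-process runs without torchrun
        return not dist.is_initialized() or dist.get_rank() == 0


    def Logger(content):
        if is_main_process():
            print(content)


    def get_lr(current_step, total_steps, lr):
        # cosine decay from lr down to a floor of lr/10
        # (the same formula the deleted per-file helpers used)
        return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


    def setup_seed(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


    def init_distributed_mode():
        # returns the local rank; 0 when not launched via torchrun
        if int(os.environ.get("RANK", -1)) == -1:
            return 0
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(f"cuda:{local_rank}")
        return local_rank


    def init_model(lm_config, from_weight='pretrain'):
        # assumption: builds tokenizer + MiniMindForCausalLM, optionally
        # warm-started from a previously exported weight file in ../out
        from transformers import AutoTokenizer
        from model.model_minimind import MiniMindForCausalLM
        tokenizer = AutoTokenizer.from_pretrained('../model')
        model = MiniMindForCausalLM(lm_config)
        if from_weight != 'none':
            moe_suffix = '_moe' if lm_config.use_moe else ''
            state_dict = torch.load(f'../out/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth', map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return model.to(device), tokenizer


    def lm_checkpoint(lm_config, weight, model=None, optimizer=None, scaler=None,
                      epoch=0, step=0, wandb=None, save_dir='../checkpoints', **extra):
        # dual purpose, inferred from the call sites: with model=None it loads a
        # resume checkpoint (or returns None); otherwise it saves full training state
        os.makedirs(save_dir, exist_ok=True)
        moe_suffix = '_moe' if lm_config.use_moe else ''
        path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_suffix}_resume.pth'
        if model is None:
            return torch.load(path, map_location='cpu') if os.path.exists(path) else None
        raw = model.module if hasattr(model, 'module') else model
        torch.save({'model': raw.state_dict(), 'optimizer': optimizer.state_dict(),
                    'scaler': scaler.state_dict() if scaler is not None else None,
                    'epoch': epoch, 'step': step,
                    'wandb_id': getattr(wandb, 'id', None) if wandb else None}, path)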
diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index 3bb15a8..f1e3526 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 import argparse
 import time
-import math
 import warnings
 import torch
 import torch.distributed as dist
@@ -14,23 +13,14 @@ from contextlib import nullcontext
 from torch import optim, nn
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
+from trainer.trainer_utils import *

 warnings.filterwarnings('ignore')


-def Logger(content):
-    if not ddp or dist.get_rank() == 0:
-        print(content)
-
-
-def get_lr(current_step, total_steps, lr):
-    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
     # placeholders for the think/answer tags
     start_of_think_ids = tokenizer('<think>').input_ids
     end_of_think_ids = tokenizer('</think>').input_ids
@@ -38,28 +28,30 @@ def train_epoch(epoch, wandb):
     end_of_answer_ids = tokenizer('</answer>').input_ids
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(train_loader):
+
+    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
         X = X.to(args.device)
         Y = Y.to(args.device)
         loss_mask = loss_mask.to(args.device)
-        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
            param_group['lr'] = lr

-        with ctx:
+        with autocast_ctx:
            res = model(X)
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())
+
+            # up-weight the special-tag positions (specific to reasoning distillation)
            sp_ids = torch.isin(Y.view(-1),
                                torch.tensor(start_of_think_ids + end_of_think_ids
                                             + start_of_answer_ids + end_of_answer_ids
                                             ).to(args.device))
-            # add an extra penalty at the positions marked by sp_ids
            loss_mask = loss_mask.view(-1)
            loss_mask_sum = loss_mask.sum()
-            loss_mask[sp_ids] = 10
+            loss_mask[sp_ids] = 10  # 10x weight on the think/answer tags
            loss_mask = loss_mask.view(Y.size())
            loss = (loss * loss_mask).sum() / loss_mask_sum
            loss += res.aux_loss
@@ -70,148 +62,112 @@
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
-
            scaler.step(optimizer)
            scaler.update()
-
            optimizer.zero_grad(set_to_none=True)

-        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
-            Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
-                    epoch + 1,
-                    args.epochs,
-                    step,
-                    iter_per_epoch,
-                    loss.item() * args.accumulation_steps,
-                    optimizer.param_groups[-1]['lr'],
-                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+            current_loss = loss.item() * args.accumulation_steps
+            current_lr = optimizer.param_groups[-1]['lr']
+            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min')
+
+            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

-            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss.item() * args.accumulation_steps,
-                           "lr": optimizer.param_groups[-1]['lr'],
-                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
-        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
-            moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
-
+            moe_suffix = '_moe' if lm_config.use_moe else ''
+            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
-
            state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
            torch.save(state_dict, ckp)
+            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()


-def init_model(lm_config):
-    tokenizer = AutoTokenizer.from_pretrained('../model')
-    model = MiniMindForCausalLM(lm_config)
-    moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
-    state_dict = torch.load(ckp, map_location=args.device)
-    model.load_state_dict(state_dict, strict=False)
-    Logger(f'LLM total params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
-    model = model.to(args.device)
-    return model, tokenizer
-
-
-def init_distributed_mode():
-    if not ddp: return
-    global ddp_local_rank, DEVICE
-
-    dist.init_process_group(backend="nccl")
-    ddp_rank = int(os.environ["RANK"])
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    ddp_world_size = int(os.environ["WORLD_SIZE"])
-    DEVICE = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(DEVICE)
-
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
-    parser.add_argument("--out_dir", type=str, default="../out")
-    parser.add_argument("--epochs", type=int, default=1)
-    parser.add_argument("--batch_size", type=int, default=8)
-    parser.add_argument("--learning_rate", type=float, default=1e-6)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
-    parser.add_argument("--num_workers", type=int, default=1)
-    parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=1)
-    parser.add_argument("--grad_clip", type=float, default=1.0)
-    parser.add_argument("--warmup_iters", type=int, default=0)
-    parser.add_argument("--log_interval", type=int, default=1)
-    parser.add_argument("--save_interval", type=int, default=50)
-    parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--hidden_size', default=512, type=int)
-    parser.add_argument('--num_hidden_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=1024, type=int)
-    parser.add_argument('--use_moe', default=False, type=bool)
-    parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl")
-
+    parser = argparse.ArgumentParser(description="MiniMind Reasoning Distillation")
+    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved weights")
+    parser.add_argument('--save_weight', default='reason', type=str, help="prefix of the saved weight file")
+    parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
+    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
+    parser.add_argument("--learning_rate", type=float, default=1e-6, help="initial learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
+    parser.add_argument("--num_workers", type=int, default=1, help="dataloader worker count")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
+    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
+    parser.add_argument("--save_interval", type=int, default=100, help="checkpoint saving interval")
+    parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
+    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
+    parser.add_argument('--max_seq_len', default=1024, type=int, help="max sequence length (truncation)")
+    parser.add_argument('--use_moe', default=False, type=bool, help="whether to use MoE")
+    parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="reasoning-distillation data path")
+    parser.add_argument('--from_weight', default='dpo', type=str, help="which weight to start from (default: dpo)")
+    parser.add_argument('--from_resume', default=0, type=int, help="auto-detect checkpoint & resume (0 = no, 1 = yes)")
+    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb project name")
    args = parser.parse_args()

-    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
-                               use_moe=args.use_moe)
-    args.save_dir = os.path.join(args.out_dir)
+    # ========== 1. Init environment & random seed ==========
+    local_rank = init_distributed_mode()
+    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+    # ========== 2. Set up dirs, model config, check for a ckp ==========
    os.makedirs(args.save_dir, exist_ok=True)
-    os.makedirs(args.out_dir, exist_ok=True)
-    tokens_per_iter = args.batch_size * args.max_seq_len
+    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
+
+    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
-
-    args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
-    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
-    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-    ddp_local_rank, DEVICE = 0, "cuda:0"
-    base_seed = 1337
-    torch.manual_seed(base_seed)
-    torch.cuda.manual_seed(base_seed)
-
-    if ddp:
-        init_distributed_mode()
-        args.device = torch.device(DEVICE)
-        rank = dist.get_rank()
-        torch.manual_seed(base_seed + rank)
-        # also seed CUDA
-        torch.cuda.manual_seed(base_seed + rank)
-
-    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+    # ========== 4. Set up wandb ==========
+    wandb = None
+    if args.use_wandb and is_main_process():
        import swanlab as wandb
-
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
-    else:
-        wandb = None
-
-    model, tokenizer = init_model(lm_config)
-
+        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+        resume = 'must' if wandb_id else None
+        wandb_run_name = f"MiniMind-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+    # ========== 5. Build model, data, optimizer ==========
+    model, tokenizer = init_model(lm_config, args.from_weight)
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=args.batch_size,
-        pin_memory=True,
-        drop_last=False,
-        shuffle=(train_sampler is None),
-        num_workers=args.num_workers,
-        sampler=train_sampler
-    )
-
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    if ddp:
+
+    # ========== 6. Restore state from the ckp ==========
+    start_epoch, start_step = 0, 0
+    if ckp_data:
+        model.load_state_dict(ckp_data['model'])
+        optimizer.load_state_dict(ckp_data['optimizer'])
+        scaler.load_state_dict(ckp_data['scaler'])
+        start_epoch = ckp_data['epoch']
+        start_step = ckp_data.get('step', 0)
+
+    # ========== 7. Wrap the model with DDP ==========
+    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
-    iter_per_epoch = len(train_loader)
-    for epoch in range(args.epochs):
+        model = DistributedDataParallel(model, device_ids=[local_rank])
+
+    # ========== 8. Start training ==========
+    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
-        train_epoch(epoch, wandb)
+        if epoch == start_epoch and start_step > 0:  # first epoch when resuming from a checkpoint
+            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+            train_epoch(epoch, loader, len(loader) + start_step + 1, tokenizer, lm_config, start_step, wandb)
+        else:  # default: start from the beginning
+            loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+            train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
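Note: SkipBatchSampler also lives in the unseen trainer_utils.py. From its call site above (`SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)`), it appears to wrap any index source and drop the first N batches so a resumed epoch continues mid-stream. A minimal sketch under those assumptions; the exact skip semantics (off-by-one relative to the saved step) depend on the real implementation:

    from torch.utils.data import Sampler


    class SkipBatchSampler(Sampler):
        # Wraps an index source (a DistributedSampler or a plain range) and
        # drops the first `skip` batches of the epoch.
        def __init__(self, source, batch_size, skip):
            self.source, self.batch_size, self.skip = source, batch_size, skip

        def __iter__(self):
            batch, emitted = [], 0
            for idx in self.source:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    emitted += 1
                    if emitted > self.skip:
                        yield batch
                    batch = []
            if batch and emitted >= self.skip:  # trailing partial batch
                yield batch

        def __len__(self):
            total = (len(self.source) + self.batch_size - 1) // self.batch_size
            return max(0, total - self.skip)

Because the wrapped DistributedSampler is still re-seeded via set_epoch(epoch), each rank skips the same shard prefix it already consumed before the pause.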
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index c97860c..f1de5d2 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -3,11 +3,10 @@ import sys
 __package__ = "trainer"
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 import argparse
 import time
-import math
 import warnings
-
 import torch
 import torch.nn.functional as F
 import torch.distributed as dist
@@ -15,23 +14,14 @@ from contextlib import nullcontext
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
+from trainer.trainer_utils import *

 warnings.filterwarnings('ignore')


-def Logger(content):
-    if not ddp or dist.get_rank() == 0:
-        print(content)
-
-
-def get_lr(current_step, total_steps, lr):
-    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
+def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
     with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()

@@ -45,25 +35,23 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduct
     return (temperature ** 2) * kl


-def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
+def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
     start_time = time.time()
-
     if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)

-    for step, (X, Y, loss_mask) in enumerate(train_loader):
+    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
-        lr = get_lr(epoch * iter_per_epoch + step,
-                    args.epochs * iter_per_epoch,
-                    args.learning_rate)
+        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # forward pass (student model)
-        with ctx:
+        with autocast_ctx:
            res = model(X)
            student_logits = res.logits

@@ -71,11 +59,11 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
        if teacher_model is not None:
            with torch.no_grad():
                teacher_logits = teacher_model(X).logits
-                vocab_size_student = student_logits.size(-1)  # N
+                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # ========== Compute losses ==========
-        # 1) Ground-Truth CE Loss (optional)
+        # 1) Ground-Truth CE Loss
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
@@ -87,10 +75,9 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
        if lm_config_student.use_moe:
            ce_loss += res.aux_loss

-        # 2) Distillation Loss (optional)
+        # 2) Distillation Loss
        if teacher_model is not None:
-            # distill only at valid token positions
-            distill_loss = distillation_loss_fn(
+            distill_loss = distillation_loss(
                student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
@@ -110,157 +97,126 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

-        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
-            Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
-                    epoch,
-                    args.epochs - 1,
-                    step,
-                    iter_per_epoch,
-                    loss.item(),
-                    optimizer.param_groups[-1]['lr'],
-                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
-                )
-            )
-
-            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
+            current_loss = loss.item() * args.accumulation_steps
+            current_lr = optimizer.param_groups[-1]['lr']
+            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} ce:{ce_loss.item():.4f} distill:{(distill_loss.item() if teacher_model is not None else 0.0):.4f} lr:{current_lr:.12f} epoch_Time:{eta_min}min')
+
+            if wandb:
                wandb.log({
-                    "loss": loss.item(),
+                    "loss": current_loss,
                    "ce_loss": ce_loss.item(),
                    "distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
-                    "lr": optimizer.param_groups[-1]['lr'],
-                    "last-time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
+                    "lr": current_lr,
+                    "epoch_Time": eta_min
                })

-        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
-            moe_path = '_moe' if lm_config_student.use_moe else ''
-            ckp = f'{args.save_dir}/full_dist_{lm_config_student.hidden_size}{moe_path}.pth'
+            moe_suffix = '_moe' if lm_config_student.use_moe else ''
+            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
            torch.save(state_dict, ckp)
+            lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()


-def init_student_model(lm_config):
-    tokenizer = AutoTokenizer.from_pretrained('../model/')
-    model = MiniMindForCausalLM(lm_config)
-    moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
-    state_dict = torch.load(ckp, map_location=args.device)
-    model.load_state_dict(state_dict, strict=False)
-    Logger(f'Student model (LLM) total params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
-    model = model.to(args.device)
-
-    return model, tokenizer
-
-
-def init_teacher_model(lm_config):
-    model = MiniMindForCausalLM(lm_config)
-    moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
-    state_dict = torch.load(ckp, map_location=args.device)
-    model.load_state_dict(state_dict, strict=False)
-    Logger(f'Teacher model (LLM) total params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
-    model = model.to(args.device)
-    return model
-
-
-def init_distributed_mode():
-    if not ddp: return
-    global ddp_local_rank, DEVICE
-
-    dist.init_process_group(backend="nccl")
-    ddp_rank = int(os.environ["RANK"])
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    ddp_world_size = int(os.environ["WORLD_SIZE"])
-    DEVICE = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(DEVICE)
-
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
-    parser.add_argument("--out_dir", type=str, default="../out")
-    parser.add_argument("--epochs", type=int, default=6)
-    parser.add_argument("--batch_size", type=int, default=32)
-    parser.add_argument("--learning_rate", type=float, default=5e-6)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
-    parser.add_argument("--num_workers", type=int, default=1)
-    parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=1)
-    parser.add_argument("--grad_clip", type=float, default=1.0)
-    parser.add_argument("--warmup_iters", type=int, default=0)
-    parser.add_argument("--log_interval", type=int, default=100)
-    parser.add_argument("--save_interval", type=int, default=100)
-    parser.add_argument("--max_seq_len", type=int, default=512)
-    parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument("--data_path", type=str, default="../dataset/sft_xxx.jsonl")
-
+    parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
+    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved weights")
+    parser.add_argument('--save_weight', default='full_dist', type=str, help="prefix of the saved weight file")
+    parser.add_argument("--epochs", type=int, default=6, help="number of training epochs")
+    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
+    parser.add_argument("--learning_rate", type=float, default=5e-6, help="initial learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
+    parser.add_argument("--num_workers", type=int, default=1, help="dataloader worker count")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
+    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
+    parser.add_argument("--save_interval", type=int, default=100, help="checkpoint saving interval")
+    parser.add_argument("--max_seq_len", type=int, default=512, help="max sequence length (truncation)")
+    parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="training data path")
+    parser.add_argument('--student_hidden_size', default=512, type=int, help="student hidden size")
+    parser.add_argument('--student_num_layers', default=8, type=int, help="number of student hidden layers")
+    parser.add_argument('--teacher_hidden_size', default=768, type=int, help="teacher hidden size")
+    parser.add_argument('--teacher_num_layers', default=16, type=int, help="number of teacher hidden layers")
+    parser.add_argument('--use_moe', default=False, type=bool, help="whether to use MoE")
+    parser.add_argument('--from_student_weight', default='full_sft', type=str, help="which weight the student starts from")
+    parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="which weight the teacher starts from")
+    parser.add_argument('--from_resume', default=0, type=int, help="auto-detect checkpoint & resume (0 = no, 1 = yes)")
+    parser.add_argument('--alpha', default=0.5, type=float, help="CE loss weight; total loss = alpha*CE + (1-alpha)*KL")
+    parser.add_argument('--temperature', default=2.0, type=float, help="distillation temperature")
+    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb project name")
    args = parser.parse_args()

-    # define the student and teacher models
-    lm_config_student = MiniMindConfig(hidden_size=512, num_hidden_layers=8)
-    lm_config_teacher = MiniMindConfig(hidden_size=768, num_hidden_layers=16)
-    args.save_dir = os.path.join(args.out_dir)
+
+    # ========== 1. Init environment & random seed ==========
+    local_rank = init_distributed_mode()
+    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+    # ========== 2. Set up dirs, model config, check for a ckp ==========
    os.makedirs(args.save_dir, exist_ok=True)
-    os.makedirs(args.out_dir, exist_ok=True)
-    tokens_per_iter = args.batch_size * args.max_seq_len
+    lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=args.use_moe)
+    lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=args.use_moe)
+    ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
+
+    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
-
-    args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
-    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
-    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-    ddp_local_rank, DEVICE = 0, "cuda:0"
-    base_seed = 1337
-    torch.manual_seed(base_seed)
-    torch.cuda.manual_seed(base_seed)
-
-    if ddp:
-        init_distributed_mode()
-        args.device = torch.device(DEVICE)
-        rank = dist.get_rank()
-        torch.manual_seed(base_seed + rank)
-        # also seed CUDA
-        torch.cuda.manual_seed(base_seed + rank)
-
-    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+    # ========== 4. Set up wandb ==========
+    wandb = None
+    if args.use_wandb and is_main_process():
        import swanlab as wandb
-
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
-    else:
-        wandb = None
-
-    # initialize the student and teacher models
-    model, tokenizer = init_student_model(lm_config_student)
-    teacher_model = init_teacher_model(lm_config_teacher)
-
+        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+        resume = 'must' if wandb_id else None
+        wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+    # ========== 5. Build the student and teacher models ==========
+    model, tokenizer = init_model(lm_config_student, args.from_student_weight)
+    Logger(f'Student model total params: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
+    teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight)
+    teacher_model.eval()
+    teacher_model.requires_grad_(False)
+    Logger(f'Teacher model total params: {sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=args.batch_size,
-        pin_memory=True,
-        drop_last=False,
-        shuffle=(train_sampler is None),
-        num_workers=args.num_workers,
-        sampler=train_sampler
-    )
-
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    if ddp:
+
+    # ========== 6. Restore state from the ckp ==========
+    start_epoch, start_step = 0, 0
+    if ckp_data:
+        model.load_state_dict(ckp_data['model'])
+        optimizer.load_state_dict(ckp_data['optimizer'])
+        scaler.load_state_dict(ckp_data['scaler'])
+        start_epoch = ckp_data['epoch']
+        start_step = ckp_data.get('step', 0)
+
+    # ========== 7. Wrap the model with DDP ==========
+    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
-    iter_per_epoch = len(train_loader)
-    for epoch in range(args.epochs):
+        model = DistributedDataParallel(model, device_ids=[local_rank])
+
+    # ========== 8. Start training ==========
+    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
-        train_epoch(epoch, wandb)
+        if epoch == start_epoch and start_step > 0:  # first epoch when resuming from a checkpoint
+            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+            train_epoch(epoch, loader, len(loader) + start_step + 1, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
+        else:  # default: start from the beginning
+            loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+            train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
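Note: the (temperature ** 2) factor in distillation_loss follows the standard soft-target recipe: softening both distributions by T shrinks the gradients by roughly 1/T², so multiplying the KL term back by T² keeps its gradient magnitude comparable to the CE term regardless of the chosen temperature. A self-contained replica of the function for quick experimentation (toy tensors only; the 6400 vocab size is an arbitrary example):

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
        # same shape contract as the patch: (num_valid_tokens, vocab_size)
        with torch.no_grad():
            teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        kl = F.kl_div(student_log_probs, teacher_probs, reduction=reduction)
        return (temperature ** 2) * kl  # rescale so gradients match the T=1 magnitude

    student = torch.randn(4, 6400)
    teacher = torch.randn(4, 6400)
    print(distillation_loss(student, teacher, temperature=2.0))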
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index 8c7bb45..b4e7b37 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 import argparse
 import time
-import math
 import warnings
 import torch
 import torch.nn.functional as F
@@ -15,55 +14,47 @@ from contextlib import nullcontext
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import DPODataset
+from trainer.trainer_utils import *

 warnings.filterwarnings('ignore')


-def Logger(content):
-    if not ddp or dist.get_rank() == 0:
-        print(content)
-
-
-def get_lr(current_step, total_steps, lr):
-    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def logits_to_probs(logits, labels):
+def logits_to_log_probs(logits, labels):
     # logits shape: (batch_size, seq_len, vocab_size)
     # labels shape: (batch_size, seq_len)
-    # probs shape: (batch_size, seq_len)
+    # log_probs shape: (batch_size, seq_len)
     log_probs = F.log_softmax(logits, dim=2)
-    probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
-    return probs
+    log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
+    return log_probs_per_token


-def dpo_loss(ref_probs, probs, mask, beta):
-    # ref_probs and probs both have shape: (batch_size, seq_len)
+def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
+    # ref_log_probs and policy_log_probs both have shape: (batch_size, seq_len)
     # https://github.com/jingyaogong/minimind/issues/298
-    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1)
-    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
-    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
+    seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8)  # guard against zero-length masks causing divide-by-zero NaNs
+    ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
+    policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()

     # split the chosen and rejected samples
-    batch_size = ref_probs.shape[0]
-    chosen_ref_probs = ref_probs[:batch_size // 2]
-    reject_ref_probs = ref_probs[batch_size // 2:]
-    chosen_probs = probs[:batch_size // 2]
-    reject_probs = probs[batch_size // 2:]
+    batch_size = ref_log_probs.shape[0]
+    chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
+    reject_ref_log_probs = ref_log_probs[batch_size // 2:]
+    chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
+    reject_policy_log_probs = policy_log_probs[batch_size // 2:]

-    pi_logratios = chosen_probs - reject_probs
-    ref_logratios = chosen_ref_probs - reject_ref_probs
+    pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
+    ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
     logits = pi_logratios - ref_logratios
     loss = -F.logsigmoid(beta * logits)
     return loss.mean()


-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
     start_time = time.time()
-    for step, batch in enumerate(train_loader):
+
+    for step, batch in enumerate(loader, start=start_step + 1):
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
@@ -74,21 +65,21 @@ def train_epoch(epoch, wandb):
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)

-        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

-        with ctx:
+        with autocast_ctx:
            with torch.no_grad():
                ref_outputs = ref_model(x)
                ref_logits = ref_outputs.logits
-            ref_probs = logits_to_probs(ref_logits, y)
-            ref_probs = ref_probs * mask
+            ref_log_probs = logits_to_log_probs(ref_logits, y)
+
            outputs = model(x)
            logits = outputs.logits
-            probs = logits_to_probs(logits, y)
-            probs = probs * mask
-            loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
+            policy_log_probs = logits_to_log_probs(logits, y)
+
+            loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()
@@ -100,150 +91,116 @@
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

-        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
-            Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
-                    epoch + 1,
-                    args.epochs,
-                    step,
-                    iter_per_epoch,
-                    loss.item() * args.accumulation_steps,
-                    optimizer.param_groups[-1]['lr'],
-                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+            current_loss = loss.item() * args.accumulation_steps
+            current_lr = optimizer.param_groups[-1]['lr']
+            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min')
+
+            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

-            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss.item() * args.accumulation_steps,
-                           "lr": optimizer.param_groups[-1]['lr'],
-                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
-        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
-            moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
-
+            moe_suffix = '_moe' if lm_config.use_moe else ''
+            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
            torch.save(state_dict, ckp)
+            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()


-def init_model(lm_config):
-    tokenizer = AutoTokenizer.from_pretrained('../model/')
-    model = MiniMindForCausalLM(lm_config)
-    moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
-    state_dict = torch.load(ckp, map_location=args.device)
-    model.load_state_dict(state_dict, strict=False)
-    # initialize the reference model
-    ref_model = MiniMindForCausalLM(lm_config)
-    ref_model.load_state_dict(state_dict, strict=False)
-    ref_model.eval()
-    ref_model.requires_grad_(False)
-
-    Logger(f'LLM total params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
-    model = model.to(args.device)
-    ref_model = ref_model.to(args.device)
-
-    return model, ref_model, tokenizer
-
-
-def init_distributed_mode():
-    if not ddp: return
-    global ddp_local_rank, DEVICE
-
-    dist.init_process_group(backend="nccl")
-    ddp_rank = int(os.environ["RANK"])
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    ddp_world_size = int(os.environ["WORLD_SIZE"])
-    DEVICE = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(DEVICE)
-
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MiniMind RLHF")
-    parser.add_argument("--out_dir", type=str, default="../out")
-    parser.add_argument("--epochs", type=int, default=2)
-    parser.add_argument("--batch_size", type=int, default=4)
-    # SFT-stage lr goes 5e-6 -> 5e-7 at length 512; for offline probability-based preference alignment, lr <= 1e-8 at length 3000 is recommended, otherwise the model easily forgets and degrades
-    parser.add_argument("--learning_rate", type=float, default=1e-8)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
-    parser.add_argument("--num_workers", type=int, default=1)
-    parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=1)
-    parser.add_argument("--grad_clip", type=float, default=1.0)
-    parser.add_argument("--warmup_iters", type=int, default=0)
-    parser.add_argument("--log_interval", type=int, default=100)
-    parser.add_argument("--save_interval", type=int, default=100)
-    parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--hidden_size', default=512, type=int)
-    parser.add_argument('--num_hidden_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=1024, type=int)
-    parser.add_argument('--use_moe', default=False, type=bool)
-    parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl")
-
+    parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
+    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved weights")
+    parser.add_argument('--save_weight', default='dpo', type=str, help="prefix of the saved weight file")
+    parser.add_argument("--epochs", type=int, default=2, help="number of training epochs")
+    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
+    parser.add_argument("--learning_rate", type=float, default=5e-8, help="initial learning rate (<=5e-8 recommended to avoid forgetting)")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
+    parser.add_argument("--num_workers", type=int, default=1, help="dataloader worker count")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
+    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
+    parser.add_argument("--save_interval", type=int, default=100, help="checkpoint saving interval")
+    parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
+    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
+    parser.add_argument('--max_seq_len', default=1024, type=int, help="max sequence length (truncation)")
+    parser.add_argument('--use_moe', default=False, type=bool, help="whether to use MoE")
+    parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO training data path")
+    parser.add_argument('--from_weight', default='full_sft', type=str, help="which weight to start from")
+    parser.add_argument('--from_resume', default=0, type=int, help="auto-detect checkpoint & resume (0 = no, 1 = yes)")
+    parser.add_argument('--beta', default=0.1, type=float, help="beta parameter in DPO")
+    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb project name")
    args = parser.parse_args()

-    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
-                               use_moe=args.use_moe)
-    args.save_dir = os.path.join(args.out_dir)
+    # ========== 1. Init environment & random seed ==========
+    local_rank = init_distributed_mode()
+    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+    # ========== 2. Set up dirs, model config, check for a ckp ==========
    os.makedirs(args.save_dir, exist_ok=True)
-    os.makedirs(args.out_dir, exist_ok=True)
-    tokens_per_iter = args.batch_size * args.max_seq_len
+    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
+
+    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
-
-    args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
-    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
-    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-    ddp_local_rank, DEVICE = 0, "cuda:0"
-    base_seed = 1337
-    torch.manual_seed(base_seed)
-    torch.cuda.manual_seed(base_seed)
-
-    if ddp:
-        init_distributed_mode()
-        args.device = torch.device(DEVICE)
-        rank = dist.get_rank()
-        torch.manual_seed(base_seed + rank)
-        # also seed CUDA
-        torch.cuda.manual_seed(base_seed + rank)
-
-    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+    # ========== 4. Set up wandb ==========
+    wandb = None
+    if args.use_wandb and is_main_process():
        import swanlab as wandb
-
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
-    else:
-        wandb = None
-
-    model, ref_model, tokenizer = init_model(lm_config)
-
+        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+        resume = 'must' if wandb_id else None
+        wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+    # ========== 5. Build the policy and reference models ==========
+    model, tokenizer = init_model(lm_config, args.from_weight)
+    Logger(f'Policy model total params: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
+    # initialize the reference model (ref_model is frozen)
+    ref_model, _ = init_model(lm_config, args.from_weight)
+    ref_model.eval()
+    ref_model.requires_grad_(False)
+    Logger(f'Reference model total params: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
+
    train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=args.batch_size,
-        pin_memory=True,
-        drop_last=False,
-        shuffle=(train_sampler is None),
-        num_workers=args.num_workers,
-        sampler=train_sampler
-    )
-
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    if ddp:
+
+    # ========== 6. Restore state from the ckp ==========
+    start_epoch, start_step = 0, 0
+    if ckp_data:
+        model.load_state_dict(ckp_data['model'])
+        optimizer.load_state_dict(ckp_data['optimizer'])
+        scaler.load_state_dict(ckp_data['scaler'])
+        start_epoch = ckp_data['epoch']
+        start_step = ckp_data.get('step', 0)
+
+    # ========== 7. Wrap the model with DDP ==========
+    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
-    iter_per_epoch = len(train_loader)
-    for epoch in range(args.epochs):
+        model = DistributedDataParallel(model, device_ids=[local_rank])
+
+    # ========== 8. Start training ==========
+    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
-        train_epoch(epoch, wandb)
+        if epoch == start_epoch and start_step > 0:  # first epoch when resuming from a checkpoint
+            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+            train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta)
+        else:  # default: start from the beginning
+            loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+            train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
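Note: a quick numeric sanity check of the sequence-mean DPO loss above (per issue #298, each sequence's token log-probs are averaged over its mask before the Bradley-Terry term). The toy tensors are invented for illustration; the function body mirrors the patch:

    import torch
    import torch.nn.functional as F

    def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
        # chosen and rejected sequences are stacked along dim 0, chosen first
        seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8)
        ref_mean = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
        policy_mean = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
        half = ref_mean.shape[0] // 2
        pi_logratios = policy_mean[:half] - policy_mean[half:]
        ref_logratios = ref_mean[:half] - ref_mean[half:]
        return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()

    ref = torch.full((2, 5), -2.0)                   # reference scores both sequences equally
    policy = torch.stack([torch.full((5,), -1.5),    # policy prefers the chosen answer...
                          torch.full((5,), -2.5)])   # ...and disfavors the rejected one
    mask = torch.ones(2, 5)
    print(dpo_loss(ref, policy, mask, beta=0.1))     # -log(sigmoid(0.1)) ~= 0.644, below log(2) ~= 0.693

When the policy and reference agree, the inner logits are 0 and the loss sits at log(2); it falls below that as the policy separates chosen from rejected more than the reference does.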
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 041801f..09fa941 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 import argparse
 import time
-import math
 import warnings
 import torch
 import torch.distributed as dist
@@ -14,34 +13,25 @@ from contextlib import nullcontext
 from torch import optim, nn
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
+from trainer.trainer_utils import *

 warnings.filterwarnings('ignore')


-def Logger(content):
-    if not ddp or dist.get_rank() == 0:
-        print(content)
-
-
-def get_lr(current_step, total_steps, lr):
-    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(train_loader):
+    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
-        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

-        with ctx:
+        with autocast_ctx:
            res = model(X)
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
@@ -63,141 +53,109 @@
            optimizer.zero_grad(set_to_none=True)

-        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
-            Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
-                    epoch + 1,
-                    args.epochs,
-                    step,
-                    iter_per_epoch,
-                    loss.item() * args.accumulation_steps,
-                    optimizer.param_groups[-1]['lr'],
-                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+            current_loss = loss.item() * args.accumulation_steps
+            current_lr = optimizer.param_groups[-1]['lr']
+            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min')
+
+            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

-            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss.item() * args.accumulation_steps,
-                           "lr": optimizer.param_groups[-1]['lr'],
-                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
-        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
-            moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
+            moe_suffix = '_moe' if lm_config.use_moe else ''
+            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
            torch.save(state_dict, ckp)
+            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
+                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
            model.train()


-def init_model(lm_config):
-    tokenizer = AutoTokenizer.from_pretrained('../model')
-    model = MiniMindForCausalLM(lm_config)
-    moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
-    state_dict = torch.load(ckp, map_location=args.device)
-    model.load_state_dict(state_dict, strict=False)
-
-    Logger(f'LLM trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
-    model = model.to(args.device)
-    return model, tokenizer
-
-
-def init_distributed_mode():
-    if not ddp: return
-    global ddp_local_rank, DEVICE
-
-    dist.init_process_group(backend="nccl")
-    ddp_rank = int(os.environ["RANK"])
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    ddp_world_size = int(os.environ["WORLD_SIZE"])
-    DEVICE = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(DEVICE)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Full SFT")
-    parser.add_argument("--out_dir", type=str, default="../out")
-    parser.add_argument("--epochs", type=int, default=2)
-    parser.add_argument("--batch_size", type=int, default=16)
-    parser.add_argument("--learning_rate", type=float, default=5e-7)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
-    parser.add_argument("--num_workers", type=int, default=1)
-    parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=1)
-    parser.add_argument("--grad_clip", type=float, default=1.0)
-    parser.add_argument("--warmup_iters", type=int, default=0)
-    parser.add_argument("--log_interval", type=int, default=100)
-    parser.add_argument("--save_interval", type=int, default=100)
-    parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--hidden_size', default=512, type=int)
-    parser.add_argument('--num_hidden_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=512, type=int)
-    parser.add_argument('--use_moe', default=False, type=bool)
-    parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl")
-
+    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved weights")
+    parser.add_argument('--save_weight', default='full_sft', type=str, help="prefix of the saved weight file")
+    parser.add_argument("--epochs", type=int, default=2, help="number of training epochs")
+    parser.add_argument("--batch_size", type=int, default=16, help="batch size")
+    parser.add_argument("--learning_rate", type=float, default=5e-7, help="initial learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
+    parser.add_argument("--num_workers", type=int, default=1, help="dataloader worker count")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
+    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
+    parser.add_argument("--save_interval", type=int, default=100, help="checkpoint saving interval")
+    parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
+    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
+    parser.add_argument('--max_seq_len', default=512, type=int, help="max sequence length (truncation)")
+    parser.add_argument('--use_moe', default=False, type=bool, help="whether to use MoE")
+    parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="training data path")
+    parser.add_argument('--from_weight', default='pretrain', type=str, help="which weight to start from; 'none' trains from scratch")
+    parser.add_argument('--from_resume', default=0, type=int, help="auto-detect checkpoint & resume (0 = no, 1 = yes)")
+    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb project name")
    args = parser.parse_args()

-    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
-                               use_moe=args.use_moe)
-    args.save_dir = os.path.join(args.out_dir)
+    # ========== 1. Init environment & random seed ==========
+    local_rank = init_distributed_mode()
+    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+    # ========== 2. Set up dirs, model config, check for a ckp ==========
    os.makedirs(args.save_dir, exist_ok=True)
-    os.makedirs(args.out_dir, exist_ok=True)
-    tokens_per_iter = args.batch_size * args.max_seq_len
+    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
+
+    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
-
-    args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
-    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
-    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-    ddp_local_rank, DEVICE = 0, "cuda:0"
-    base_seed = 1337
-    torch.manual_seed(base_seed)
-    torch.cuda.manual_seed(base_seed)
-
-    if ddp:
-        init_distributed_mode()
-        args.device = torch.device(DEVICE)
-        rank = dist.get_rank()
-        torch.manual_seed(base_seed + rank)
-        # also seed CUDA
-        torch.cuda.manual_seed(base_seed + rank)
-
-    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+    # ========== 4. Set up wandb ==========
+    wandb = None
+    if args.use_wandb and is_main_process():
        import swanlab as wandb
-
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
-    else:
-        wandb = None
-
-    model, tokenizer = init_model(lm_config)
-
+        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+        resume = 'must' if wandb_id else None
+        wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+    # ========== 5. Build model, data, optimizer ==========
+    model, tokenizer = init_model(lm_config, args.from_weight)
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=args.batch_size,
-        pin_memory=True,
-        drop_last=False,
-        shuffle=(train_sampler is None),
-        num_workers=args.num_workers,
-        sampler=train_sampler
-    )
-
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    if ddp:
+
+    # ========== 6. Restore state from the ckp ==========
+    start_epoch, start_step = 0, 0
+    if ckp_data:
+        model.load_state_dict(ckp_data['model'])
+        optimizer.load_state_dict(ckp_data['optimizer'])
+        scaler.load_state_dict(ckp_data['scaler'])
+        start_epoch = ckp_data['epoch']
+        start_step = ckp_data.get('step', 0)
+
+    # ========== 7. Wrap the model with DDP ==========
+    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
-    iter_per_epoch = len(train_loader)
-    for epoch in range(args.epochs):
+        model = DistributedDataParallel(model, device_ids=[local_rank])
+
+    # ========== 8. Start training ==========
+    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
-        train_epoch(epoch, wandb)
+        if epoch == start_epoch and start_step > 0:  # first epoch when resuming from a checkpoint
+            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+            train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
+        else:  # default: start from the beginning
+            loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+            train_epoch(epoch, loader, len(loader), 0, wandb)
使用reward model计算奖励 with torch.no_grad(): reward_model_scores = [] batch_size = len(prompts) @@ -105,8 +92,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer): return rewards -def grpo_train_epoch(epoch, wandb): - for step, batch in enumerate(train_loader): +def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None): + for step, batch in enumerate(loader, start=start_step + 1): prompts = batch['prompt'] # list[str], length B prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False, padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P] @@ -115,7 +102,9 @@ def grpo_train_epoch(epoch, wandb): prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:] with torch.no_grad(): - outputs = (model.module if ddp else model).generate( + # DDP 模型需要使用 .module 访问 generate 方法 + model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model + outputs = model_for_gen.generate( **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8, num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R] @@ -161,36 +150,33 @@ def grpo_train_epoch(epoch, wandb): scheduler.step() optimizer.zero_grad() - if step % args.log_interval == 0 or step == iter_per_epoch - 1: + if step % args.log_interval == 0 or step == iters: policy_loss_val = loss.item() avg_reward_val = rewards.mean().item() avg_len_val = completion_mask.sum(dim=1).float().mean().item() current_lr = optimizer.param_groups[0]['lr'] - Logger( - f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, ' - f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, ' - f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}') + Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, ' + f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, ' + f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}') - if wandb and (not ddp or dist.get_rank() == 0): - log_dict = { + if wandb and is_main_process(): + wandb.log({ "policy_loss": policy_loss_val, "reward": avg_reward_val, "avg_response_len": avg_len_val, "advantages_mean": advantages.mean().item(), "learning_rate": current_lr - } - wandb.log(log_dict) + }) - if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0): + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() - moe_path = '_moe' if lm_config.use_moe else '' - suffix = 'grpo' - ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth' - - state_dict = model.module.state_dict() if isinstance(model, - torch.nn.parallel.DistributedDataParallel) else model.state_dict() + moe_suffix = '_moe' if lm_config.use_moe else '' + ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' + state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() torch.save({k: v.half() for k, v in state_dict.items()}, ckp) + lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, + epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler) model.train() del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps @@ -199,119 +185,114 @@ def grpo_train_epoch(epoch, wandb): gc.collect() -def init_model(lm_config): - tokenizer = 
AutoTokenizer.from_pretrained('../model/') - model = MiniMindForCausalLM(lm_config) - moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth' - if args.reasoning == 1: - ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth' - state_dict = torch.load(ckp, map_location=args.device) - model.load_state_dict(state_dict, strict=False) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)") + parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") + parser.add_argument('--save_weight', default='grpo', type=str, help="保存权重的前缀名") + parser.add_argument("--epochs", type=int, default=1, help="训练轮数") + parser.add_argument("--batch_size", type=int, default=2, help="batch size") + parser.add_argument("--learning_rate", type=float, default=8e-8, help="初始学习率") + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") + parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") + parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") + parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") + parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") + parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔") + parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔") + parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") + parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE") + parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度") + parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度") + parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径") + parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数") + parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数") + parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型') + parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径") + parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是") + parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") + parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名") + args = parser.parse_args() + # ========== 1. 初始化环境和随机种子 ========== + local_rank = init_distributed_mode() + if dist.is_initialized(): args.device = f"cuda:{local_rank}" + setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + + # ========== 2. 配置目录、模型参数、检查ckp ========== + os.makedirs(args.save_dir, exist_ok=True) + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, + max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe) + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None + + # ========== 3. 设置混合精度 ========== + device_type = "cuda" if "cuda" in args.device else "cpu" + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) + + # ========== 4. 
配wandb ========== + wandb = None + if args.use_wandb and is_main_process(): + import swanlab as wandb + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) + + # ========== 5. 初始化模型和数据 ========== + tokenizer = AutoTokenizer.from_pretrained('../model/') + moe_suffix = '_moe' if lm_config.use_moe else '' + base_weight = "reason" if args.reasoning == 1 else "full_sft" + ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth' + state_dict = torch.load(ckp, map_location=args.device) + # Policy模型 + model = MiniMindForCausalLM(lm_config) + model.load_state_dict(state_dict, strict=False) + model = model.to(args.device) + Logger(f'Policy模型总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M') + # Reference模型 ref_model = MiniMindForCausalLM(lm_config) ref_model.load_state_dict(state_dict, strict=False) ref_model.eval().requires_grad_(False) - - Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') - model = model.to(args.device) ref_model = ref_model.to(args.device) - - reward_name = "../../internlm2-1_8b-reward" + # Reward模型 reward_model = AutoModel.from_pretrained( - reward_name, - device_map="cuda", - torch_dtype=torch.float16, - trust_remote_code=True, + args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True ).to(args.device).eval().requires_grad_(False) - reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True) - - return model, ref_model, tokenizer, reward_model, reward_tokenizer - - -def init_distributed_mode(): - if not ddp: return - global ddp_local_rank, DEVICE - dist.init_process_group(backend="nccl") - ddp_local_rank = int(os.environ["LOCAL_RANK"]) - DEVICE = f"cuda:{ddp_local_rank}" - torch.cuda.set_device(DEVICE) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--out_dir", type=str, default="../out") - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=2) - parser.add_argument("--learning_rate", type=float, default=8e-8) - parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") - parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", action="store_true") - parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO") - parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--ddp", action="store_true") - parser.add_argument("--accumulation_steps", type=int, default=1) - parser.add_argument("--grad_clip", type=float, default=1.0) - parser.add_argument("--log_interval", type=int, default=1) - parser.add_argument("--save_interval", type=int, default=10) - parser.add_argument('--hidden_size', default=512, type=int) - parser.add_argument('--num_hidden_layers', default=8, type=int) - parser.add_argument('--use_moe', default=False, type=bool) - parser.add_argument('--max_seq_len', default=66, type=int) - parser.add_argument("--max_gen_len", type=int, default=1536) - parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl") - parser.add_argument("--num_generations", type=int, default=8) - parser.add_argument("--beta", type=float, default=0.02) - 
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型') - args = parser.parse_args() - - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, - max_seq_len=args.max_seq_len + args.max_gen_len, - use_moe=args.use_moe) - args.save_dir = os.path.join(args.out_dir) - os.makedirs(args.save_dir, exist_ok=True) - os.makedirs(args.out_dir, exist_ok=True) - - ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda') - ddp = int(os.environ.get("RANK", -1)) != -1 - ddp_local_rank, DEVICE = 0, "cuda:0" - - base_seed = 1337 - torch.manual_seed(base_seed) - torch.cuda.manual_seed(base_seed) - - if ddp: - init_distributed_mode() - args.device = torch.device(DEVICE) - rank = dist.get_rank() - torch.manual_seed(base_seed + rank) - # 同时设置 CUDA 的随机种子 - torch.cuda.manual_seed(base_seed + rank) - - if args.use_wandb and (not ddp or ddp_local_rank == 0): - import swanlab as wandb - - wandb.init(project=args.wandb_project) - else: - wandb = None - - model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config) + reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True) + # 数据和优化器 train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) - train_sampler = DistributedSampler(train_ds) if ddp else None - train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, + train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) + loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler) + iters = len(loader_for_count) + total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs + scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) + + # ========== 6. 从ckp恢复状态 ========== + start_epoch, start_step = 0, 0 + if ckp_data: + model.load_state_dict(ckp_data['model']) + optimizer.load_state_dict(ckp_data['optimizer']) + scheduler.load_state_dict(ckp_data['scheduler']) + start_epoch = ckp_data['epoch'] + start_step = ckp_data.get('step', 0) + + # ========== 7. DDP包模型 ========== + if dist.is_initialized(): + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} + model = DistributedDataParallel(model, device_ids=[local_rank]) + + # ========== 8. 
开始训练 ========== + for epoch in range(start_epoch, args.epochs): + train_sampler and train_sampler.set_epoch(epoch) + if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 + batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) + loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) + Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') + grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb) + else: # 默认从头开始 + loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=False, shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler) - - optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) - - iter_per_epoch = len(train_loader) - total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs - scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) - - if ddp: - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) - - for epoch in range(args.epochs): - train_sampler and train_sampler.set_epoch(epoch) - grpo_train_epoch(epoch, wandb) + grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb) diff --git a/trainer/train_lora.py b/trainer/train_lora.py index df9f9ae..b6fc2b0 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -6,49 +6,39 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import argparse import time -import math import warnings import torch -from torch import optim, nn import torch.distributed as dist from contextlib import nullcontext +from torch import optim, nn +from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler -from transformers import AutoTokenizer, AutoModelForCausalLM -from model.model_minimind import MiniMindConfig, MiniMindForCausalLM +from model.model_minimind import MiniMindConfig from dataset.lm_dataset import SFTDataset -from model.model_lora import load_lora, save_lora, apply_lora +from model.model_lora import save_lora, apply_lora +from trainer.trainer_utils import * warnings.filterwarnings('ignore') -# Logger function -def Logger(content): - if not ddp or dist.get_rank() == 0: - print(content) - - -def get_lr(current_step, total_steps, lr): - return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps)) - - -# 代码和full_sft「几乎」一致 -def train_epoch(epoch, wandb): +def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None): loss_fct = nn.CrossEntropyLoss(reduction='none') start_time = time.time() - for step, (X, Y, loss_mask) in enumerate(train_loader): + for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1): X = X.to(args.device) Y = Y.to(args.device) loss_mask = loss_mask.to(args.device) - lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) + lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr'] = lr - with ctx: + with autocast_ctx: res = model(X) loss = loss_fct( res.logits.view(-1, res.logits.size(-1)), Y.view(-1) ).view(Y.size()) + loss = (loss * loss_mask).sum() / loss_mask.sum() loss += res.aux_loss loss = loss 
/ args.accumulation_steps @@ -64,146 +54,122 @@ def train_epoch(epoch, wandb): optimizer.zero_grad(set_to_none=True) - if step % args.log_interval == 0 or step == iter_per_epoch - 1: + if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time - Logger( - 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format( - epoch + 1, - args.epochs, - step, - iter_per_epoch, - loss.item() * args.accumulation_steps, - optimizer.param_groups[-1]['lr'], - spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) + current_loss = loss.item() * args.accumulation_steps + current_lr = optimizer.param_groups[-1]['lr'] + eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 + + Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:') + + if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min}) - if (wandb is not None) and (not ddp or dist.get_rank() == 0): - wandb.log({"loss": loss.item() * args.accumulation_steps, - "lr": optimizer.param_groups[-1]['lr'], - "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) - - if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0): + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() - lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth' - os.makedirs(os.path.dirname(lora_save_path), exist_ok=True) - # 【区别1】只保存lora权重即可 + lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth' + # LoRA只保存LoRA权重 save_lora(model, lora_save_path) + lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') model.train() -def init_model(lm_config): - tokenizer = AutoTokenizer.from_pretrained('../model/') - model = MiniMindForCausalLM(lm_config) - moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth' - state_dict = torch.load(ckp, map_location=args.device) - model.load_state_dict(state_dict, strict=False) - return model.to(args.device), tokenizer - - -def init_distributed_mode(): - if not ddp: return - global ddp_local_rank, DEVICE - - dist.init_process_group(backend="nccl") - ddp_rank = int(os.environ["RANK"]) - ddp_local_rank = int(os.environ["LOCAL_RANK"]) - ddp_world_size = int(os.environ["WORLD_SIZE"]) - DEVICE = f"cuda:{ddp_local_rank}" - torch.cuda.set_device(DEVICE) - - if __name__ == "__main__": - parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA") - parser.add_argument("--out_dir", type=str, default="../out") - parser.add_argument("--epochs", type=int, default=50) - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--learning_rate", type=float, default=1e-4) - parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") - parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", action="store_true") - parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT") - parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--ddp", action="store_true") - parser.add_argument("--accumulation_steps", type=int, default=1) - parser.add_argument("--grad_clip", type=float, default=1.0) - 
parser.add_argument("--warmup_iters", type=int, default=0) - parser.add_argument("--log_interval", type=int, default=10) - parser.add_argument("--save_interval", type=int, default=1) - parser.add_argument('--local_rank', type=int, default=-1) - parser.add_argument('--hidden_size', default=512, type=int) - parser.add_argument('--num_hidden_layers', default=8, type=int) - parser.add_argument('--max_seq_len', default=512, type=int) - parser.add_argument('--use_moe', default=False, type=bool) - parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl") - parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)") + parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning") + parser.add_argument("--save_dir", type=str, default="../out/lora", help="模型保存目录") + parser.add_argument("--lora_name", type=str, default="lora_identity", help="LoRA权重名称(如lora_identity/lora_medical等)") + parser.add_argument("--epochs", type=int, default=50, help="训练轮数") + parser.add_argument("--batch_size", type=int, default=32, help="batch size") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率") + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") + parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") + parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") + parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") + parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") + parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔") + parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔") + parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") + parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度") + parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE") + parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径") + parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft") + parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是") + parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") + parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名") args = parser.parse_args() - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, - use_moe=args.use_moe) - args.save_dir = os.path.join(args.out_dir) + # ========== 1. 初始化环境和随机种子 ========== + local_rank = init_distributed_mode() + if dist.is_initialized(): args.device = f"cuda:{local_rank}" + setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + + # ========== 2. 配置目录、模型参数、检查ckp ========== os.makedirs(args.save_dir, exist_ok=True) - os.makedirs(args.out_dir, exist_ok=True) - tokens_per_iter = args.batch_size * args.max_seq_len + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe) + ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None + + # ========== 3. 
设置混合精度 ========== device_type = "cuda" if "cuda" in args.device else "cpu" - - ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast() - ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? - ddp_local_rank, DEVICE = 0, "cuda:0" - base_seed = 1337 - torch.manual_seed(base_seed) - torch.cuda.manual_seed(base_seed) - - if ddp: - init_distributed_mode() - args.device = torch.device(DEVICE) - rank = dist.get_rank() - torch.manual_seed(base_seed + rank) - # 同时设置 CUDA 的随机种子 - torch.cuda.manual_seed(base_seed + rank) - - args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" - if args.use_wandb and (not ddp or ddp_local_rank == 0): + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) + + # ========== 4. 配wandb ========== + wandb = None + if args.use_wandb and is_main_process(): import swanlab as wandb - - wandb.init(project=args.wandb_project, name=args.wandb_run_name) - else: - wandb = None - - model, tokenizer = init_model(lm_config) + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) + + # ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ========== + model, tokenizer = init_model(lm_config, args.from_weight) apply_lora(model) - - total_params = sum(p.numel() for p in model.parameters()) # 总参数数量 - lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) # LoRA 参数数量 - if not ddp or dist.get_rank() == 0: - print(f"LLM 总参数量: {total_params}") - print(f"LoRA 参数量: {lora_params_count}") - print(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%") - - for name, param in model.named_parameters(): - if 'lora' not in name: - param.requires_grad = False + + # 统计参数 + total_params = sum(p.numel() for p in model.parameters()) + lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) + Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M") + Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M") + Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%") + + # 冻结非LoRA参数,收集LoRA参数 lora_params = [] for name, param in model.named_parameters(): if 'lora' in name: + param.requires_grad = True lora_params.append(param) - - # 只对 LoRA 参数进行优化 - optimizer = optim.AdamW(lora_params, lr=args.learning_rate) + else: + param.requires_grad = False + + # ========== 6. 定义数据和优化器 ========== train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len) - train_sampler = DistributedSampler(train_ds) if ddp else None - train_loader = DataLoader( - train_ds, - batch_size=args.batch_size, - pin_memory=True, - drop_last=False, - shuffle=(train_sampler is None), - num_workers=args.num_workers, - sampler=train_sampler - ) - - scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) - iter_per_epoch = len(train_loader) - - for epoch in range(args.epochs): + train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None + scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) + optimizer = optim.AdamW(lora_params, lr=args.learning_rate) + + # ========== 7. 
从ckp恢复状态 ========== + start_epoch, start_step = 0, 0 + if ckp_data: + model.load_state_dict(ckp_data['model'], strict=False) + optimizer.load_state_dict(ckp_data['optimizer']) + scaler.load_state_dict(ckp_data['scaler']) + start_epoch = ckp_data['epoch'] + start_step = ckp_data.get('step', 0) + + # ========== 8. DDP包模型 ========== + if dist.is_initialized(): + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} + model = DistributedDataParallel(model, device_ids=[local_rank]) + + # ========== 9. 开始训练 ========== + for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) - train_epoch(epoch, wandb) + if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 + batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) + loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) + Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') + train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb) + else: # 默认从头开始 + loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) + train_epoch(epoch, loader, len(loader), lora_params, 0, wandb) diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 7652c22..27a82b7 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -6,27 +6,43 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import argparse import re +import warnings import torch import torch.distributed as dist import torch.nn.functional as F +from transformers import AutoTokenizer +from contextlib import nullcontext from torch import optim, nn from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler -from transformers import AutoTokenizer, AutoModel -from model.model_minimind import MiniMindConfig, MiniMindForCausalLM -from dataset.lm_dataset import RLAIFDataset from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import CosineAnnealingLR +from transformers import AutoModel +from model.model_minimind import MiniMindConfig, MiniMindForCausalLM +from dataset.lm_dataset import RLAIFDataset +from trainer.trainer_utils import * + +warnings.filterwarnings('ignore') -def Logger(content): - if not ddp or dist.get_rank() == 0: - print(content) +# 自定义的Critic模型,继承自MiniMindLM +class CriticModel(MiniMindForCausalLM): + def __init__(self, params): + super().__init__(params) + # 替换lm_head为输出单一价值的线性层 + self.value_head = nn.Linear(params.hidden_size, 1) + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + # 使用基础模型获取隐藏状态 + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + hidden_states = self.model.norm(outputs[0]) + # 使用value_head获取价值估计 + values = self.value_head(hidden_states).squeeze(-1) + return values def calculate_rewards(prompts, responses, reward_model, reward_tokenizer): """整合所有奖励函数计算总奖励""" - def reasoning_model_reward(rewards): # 1. 
格式奖励(仅针对训练推理模型时使用) pattern = r"^\n.*?\n\n\n.*?\n$" @@ -66,7 +82,7 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer): # 格式奖励 if args.reasoning == 1: - rewards = reasoning_model_reward(rewards) # 训练推理模型时使用 + rewards = reasoning_model_reward(rewards) # 使用reward model计算整个response的奖励 with torch.no_grad(): @@ -91,7 +107,6 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer): tmp_chat = messages + [{"role": "assistant", "content": answer_content}] answer_score = reward_model.get_score(reward_tokenizer, tmp_chat) answer_score = max(min(answer_score, scale), -scale) - score = score * 0.4 + answer_score * 0.6 reward_model_scores.append(score) @@ -101,19 +116,20 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer): return rewards -def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_scheduler, critic_scheduler): +def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None): actor_model.train() critic_model.train() - is_master = (not ddp) or dist.get_rank() == 0 - for step, batch in enumerate(train_loader): + for step, batch in enumerate(loader, start=start_step + 1): prompts = batch["prompt"] # list[str], length B enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len).to(args.device) # input_ids: [B, P], attention_mask: [B, P] prompt_lengths = enc.attention_mask.sum(dim=1) # [B] with torch.no_grad(): - gen_out = actor_model.generate( + # DDP 模型需要使用 .module 访问 generate 方法 + model_for_gen = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model + gen_out = model_for_gen.generate( input_ids=enc.input_ids, attention_mask=enc.attention_mask, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) # [B, P+R] @@ -164,7 +180,7 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch actor_optimizer.zero_grad() critic_optimizer.zero_grad() - if is_master: + if is_main_process(): response_ids = gen_out[:, enc.input_ids.shape[1]:] is_eos = (response_ids == tokenizer.eos_token_id) eos_indices = torch.argmax(is_eos.int(), dim=1) @@ -181,8 +197,8 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch actor_lr = actor_optimizer.param_groups[0]['lr'] critic_lr = critic_optimizer.param_groups[0]['lr'] - if wandb_run is not None: - wandb_run.log({ + if wandb is not None: + wandb.log({ "actor_loss": actor_loss_val, "critic_loss": critic_loss_val, "reward": reward_val, @@ -192,183 +208,158 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch "actor_lr": actor_lr, }) - Logger(f"Epoch: {epoch}, Step: {step + 1}/{len(train_loader)}, " + Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, " f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, " f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, " f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}") if (step + 1) % args.update_old_actor_freq == 0: - state_dict = actor_model.module.state_dict() if isinstance(actor_model, torch.nn.parallel.DistributedDataParallel) else actor_model.state_dict() + state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict() 
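[Reviewer note on `SkipBatchSampler`: every trainer's resume branch in this patch (the `# ========== 8. 开始训练` sections, including the PPO loop at the end of this file) rebuilds its DataLoader around `SkipBatchSampler` so the already-trained batches of the interrupted epoch are consumed without being trained on again. The class lives in trainer/trainer_utils.py (hunk not shown); below is a self-contained sketch consistent with how it is called — the constructor signature and the exact skip accounting are assumptions:

import math

class SkipBatchSampler:
    """Yields index batches like BatchSampler, but skips the first `skip` batches (mid-epoch resume)."""
    def __init__(self, sampler, batch_size, skip=0):
        self.sampler = sampler          # a DistributedSampler, or range(len(dataset)) on a single GPU
        self.batch_size = batch_size
        self.skip = skip
    def __iter__(self):
        batch, emitted = [], 0
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                emitted += 1
                if emitted > self.skip:  # only yield once the skipped prefix is consumed
                    yield batch
                batch = []
        if batch and emitted >= self.skip:  # trailing partial batch (drop_last=False)
            yield batch
    def __len__(self):
        total = math.ceil(len(self.sampler) / self.batch_size)
        return max(total - self.skip, 0)

Note that `__len__` returns the remaining (not total) batch count, which is why the resume branches pass `len(loader) + start_step + 1` as the per-epoch iteration count and start `enumerate(loader, start=start_step + 1)` from the saved step.]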
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()}) old_actor_model.to(args.device) - if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0): + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): actor_model.eval() - moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/ppo_actor_{lm_config.hidden_size}{moe_path}.pth' - - if isinstance(actor_model, torch.nn.parallel.DistributedDataParallel): - state_dict = actor_model.module.state_dict() - else: - state_dict = actor_model.state_dict() - - state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存 - torch.save(state_dict, ckp) + moe_suffix = '_moe' if lm_config.use_moe else '' + ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' + actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict() + torch.save({k: v.half() for k, v in actor_state.items()}, ckp) + + # 使用 lm_checkpoint 保存完整状态(包括 critic) + lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer, + epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', + scheduler=actor_scheduler, critic_model=critic_model, + critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler) actor_model.train() -# 自定义的Critic模型,继承自MiniMindLM -class CriticModel(MiniMindForCausalLM): - def __init__(self, params): - super().__init__(params) - # 替换lm_head为输出单一价值的线性层 - self.value_head = nn.Linear(params.hidden_size, 1) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)") + parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") + parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名") + parser.add_argument("--epochs", type=int, default=1, help="训练轮数") + parser.add_argument("--batch_size", type=int, default=2, help="batch size") + parser.add_argument("--learning_rate", type=float, default=8e-8, help="Actor学习率") + parser.add_argument("--critic_learning_rate", type=float, default=8e-8, help="Critic学习率") + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") + parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") + parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") + parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") + parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") + parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔") + parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔") + parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") + parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE") + parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度") + parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度") + parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径") + parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数") + parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数") + parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数") + 
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型') + parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率") + parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径") + parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是") + parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") + parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名") + args = parser.parse_args() - def forward(self, input_ids=None, attention_mask=None, **kwargs): - # 使用基础模型获取隐藏状态 - outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) - # self.model 返回的是一个元组,第一个元素是 last_hidden_state - hidden_states = self.model.norm(outputs[0]) - # 使用value_head获取价值估计 - values = self.value_head(hidden_states).squeeze(-1) - return values - - -def init_model(lm_config): + # ========== 1. 初始化环境和随机种子 ========== + local_rank = init_distributed_mode() + if dist.is_initialized(): args.device = f"cuda:{local_rank}" + setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + + # ========== 2. 配置目录、模型参数、检查ckp ========== + os.makedirs(args.save_dir, exist_ok=True) + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe) + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None + + # ========== 3. 设置混合精度 ========== + device_type = "cuda" if "cuda" in args.device else "cpu" + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) + + # ========== 4. 配wandb ========== + wandb = None + if args.use_wandb and is_main_process(): + import swanlab as wandb + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) + + # ========== 5. 
初始化模型和数据 ========== tokenizer = AutoTokenizer.from_pretrained('../model/', padding_side='left') - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/{"reason" if args.reasoning == 1 else "full_sft"}_{lm_config.hidden_size}{moe_path}.pth' + moe_suffix = '_moe' if lm_config.use_moe else '' + base_weight = "reason" if args.reasoning == 1 else "full_sft" + ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth' state_dict = torch.load(ckp, map_location=args.device) - + # Actor模型 actor_model = MiniMindForCausalLM(lm_config) actor_model.load_state_dict(state_dict, strict=False) actor_model = actor_model.to(args.device) - + Logger(f'Actor模型总参数量:{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} M') + # Old Actor模型 old_actor_model = MiniMindForCausalLM(lm_config) old_actor_model.load_state_dict(state_dict, strict=False) old_actor_model = old_actor_model.eval().requires_grad_(False).to(args.device) - + # Reference模型 ref_model = MiniMindForCausalLM(lm_config) ref_model.load_state_dict(state_dict, strict=False) ref_model = ref_model.eval().requires_grad_(False).to(args.device) - + # Critic模型 critic_model = CriticModel(lm_config) critic_model.load_state_dict(state_dict, strict=False) critic_model = critic_model.to(args.device) - - reward_name = "../../internlm2-1_8b-reward" + Logger(f'Critic模型总参数量:{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} M') + # Reward模型 reward_model = AutoModel.from_pretrained( - reward_name, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True + args.reward_model_path, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True ).to(args.device).eval().requires_grad_(False) - reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True) - - Logger(f'Actor模型总参数量:{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} 百万') - Logger(f'Critic模型总参数量:{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} 百万') - - return actor_model, old_actor_model, ref_model, critic_model, reward_model, tokenizer, reward_tokenizer - - -def init_distributed_mode(): - if not ddp: return - global ddp_local_rank, DEVICE - dist.init_process_group(backend="nccl") - ddp_local_rank = int(os.environ["LOCAL_RANK"]) - DEVICE = f"cuda:{ddp_local_rank}" - torch.cuda.set_device(DEVICE) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--out_dir", type=str, default="../out") - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=2) - parser.add_argument("--learning_rate", type=float, default=8e-8) - parser.add_argument("--critic_learning_rate", type=float, default=8e-8) - parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") - parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", action="store_true") - parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO") - parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--ddp", action="store_true") - parser.add_argument("--accumulation_steps", type=int, default=1) - parser.add_argument("--grad_clip", type=float, default=1.0) - parser.add_argument("--log_interval", type=int, default=1) - parser.add_argument("--save_interval", type=int, default=10) - 
parser.add_argument('--hidden_size', default=512, type=int) - parser.add_argument('--num_hidden_layers', default=8, type=int) - parser.add_argument('--use_moe', default=False, type=bool) - parser.add_argument('--max_seq_len', default=66, type=int) - parser.add_argument("--max_gen_len", type=int, default=1536) - parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl") - parser.add_argument("--clip_epsilon", type=float, default=0.1) - parser.add_argument("--vf_coef", type=float, default=0.5) - parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数") - parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型') - parser.add_argument("--update_old_actor_freq", type=int, default=4, help="频率:每处理n个batch后更新old_actor_model") - args = parser.parse_args() - - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, - use_moe=args.use_moe) - args.save_dir = os.path.join(args.out_dir) - os.makedirs(args.save_dir, exist_ok=True) - os.makedirs(args.out_dir, exist_ok=True) - - ddp = int(os.environ.get("RANK", -1)) != -1 - ddp_local_rank, DEVICE = 0, "cuda:0" - base_seed = 1337 - torch.manual_seed(base_seed) - torch.cuda.manual_seed(base_seed) - - if ddp: - init_distributed_mode() - args.device = torch.device(DEVICE) - rank = dist.get_rank() - torch.manual_seed(base_seed + rank) - # 同时设置 CUDA 的随机种子 - torch.cuda.manual_seed(base_seed + rank) - - if args.use_wandb and (not ddp or ddp_local_rank == 0): - import swanlab as wandb - - wandb.init(project=args.wandb_project) - else: - wandb = None - - # 初始化所有模型 - actor_model, old_actor_model, ref_model, critic_model, reward_model, tokenizer, reward_tokenizer = init_model(lm_config=lm_config) - - # 准备数据集和数据加载器 + reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True) + # 数据和优化器 train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len)) - train_sampler = DistributedSampler(train_ds) if ddp else None - train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, - drop_last=False, shuffle=(train_sampler is None), - num_workers=args.num_workers, sampler=train_sampler) - - # 初始化优化器 + train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate) critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate) - - iter_per_epoch = len(train_loader) - total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs + loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler) + iters = len(loader_for_count) + total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) - critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, - eta_min=args.critic_learning_rate / 10) - - # 如果使用分布式训练,包装模型 - if ddp: + critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10) + + # ========== 6. 
从ckp恢复状态 ========== + start_epoch, start_step = 0, 0 + if ckp_data: + actor_model.load_state_dict(ckp_data['model']) + critic_model.load_state_dict(ckp_data['critic_model']) + actor_optimizer.load_state_dict(ckp_data['optimizer']) + critic_optimizer.load_state_dict(ckp_data['critic_optimizer']) + actor_scheduler.load_state_dict(ckp_data['scheduler']) + critic_scheduler.load_state_dict(ckp_data['critic_scheduler']) + start_epoch = ckp_data['epoch'] + start_step = ckp_data.get('step', 0) + + # ========== 7. DDP包模型 ========== + if dist.is_initialized(): actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - actor_model = DistributedDataParallel(actor_model, device_ids=[ddp_local_rank]) - critic_model = DistributedDataParallel(critic_model, device_ids=[ddp_local_rank]) - # old_actor_model 不需要DDP包装,因为它只在主进程上用于计算,并且不进行梯度更新 + actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank]) + critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank]) old_actor_model.to(args.device) - - for epoch in range(args.epochs): + + # ========== 8. 开始训练 ========== + for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) - ppo_train_epoch(epoch, wandb, old_actor_model, ref_model, actor_scheduler, critic_scheduler) - - if ddp: - dist.destroy_process_group() + if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 + batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) + loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) + Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') + ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model, + actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb) + else: # 默认从头开始 + loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), + sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) + ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model, + actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb) diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 36a3cd8..18cc445 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -6,48 +6,38 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import argparse import time -import math import warnings import torch import torch.distributed as dist +from contextlib import nullcontext from torch import optim, nn from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler -from contextlib import nullcontext -from transformers import AutoTokenizer -from model.model_minimind import MiniMindConfig, MiniMindForCausalLM +from model.model_minimind import MiniMindConfig from dataset.lm_dataset import PretrainDataset +from trainer.trainer_utils import * warnings.filterwarnings('ignore') -def Logger(content): - if not ddp or dist.get_rank() == 0: - print(content) - - -def get_lr(current_step, total_steps, lr): - return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps)) - - -def train_epoch(epoch, wandb): +def train_epoch(epoch, loader, iters, start_step=0, wandb=None): loss_fct = nn.CrossEntropyLoss(reduction='none') start_time = time.time() - for 
step, (X, Y, loss_mask) in enumerate(train_loader): + for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1): X = X.to(args.device) Y = Y.to(args.device) loss_mask = loss_mask.to(args.device) - - lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) + lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr'] = lr - with ctx: + with autocast_ctx: res = model(X) loss = loss_fct( res.logits.view(-1, res.logits.size(-1)), Y.view(-1) ).view(Y.size()) + loss = (loss * loss_mask).sum() / loss_mask.sum() loss += res.aux_loss loss = loss / args.accumulation_steps @@ -63,139 +53,108 @@ def train_epoch(epoch, wandb): optimizer.zero_grad(set_to_none=True) - if step % args.log_interval == 0 or step == iter_per_epoch - 1: + if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time - Logger( - 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format( - epoch + 1, - args.epochs, - step, - iter_per_epoch, - loss.item() * args.accumulation_steps, - optimizer.param_groups[-1]['lr'], - spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) + current_loss = loss.item() * args.accumulation_steps + current_lr = optimizer.param_groups[-1]['lr'] + eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 + + Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:') + + if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min}) - if (wandb is not None) and (not ddp or dist.get_rank() == 0): - wandb.log({"loss": loss.item() * args.accumulation_steps, - "lr": optimizer.param_groups[-1]['lr'], - "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) - - if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0): + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() - moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth' - + moe_suffix = '_moe' if lm_config.use_moe else '' + ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' if isinstance(model, torch.nn.parallel.DistributedDataParallel): state_dict = model.module.state_dict() else: state_dict = model.state_dict() - state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存 torch.save(state_dict, ckp) + lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') model.train() -def init_model(lm_config): - tokenizer = AutoTokenizer.from_pretrained('../model/') - model = MiniMindForCausalLM(lm_config).to(args.device) - Logger(f'LLM可训练总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') - return model, tokenizer - - -def init_distributed_mode(): - if not ddp: return - global ddp_local_rank, DEVICE - - dist.init_process_group(backend="nccl") - ddp_rank = int(os.environ["RANK"]) - ddp_local_rank = int(os.environ["LOCAL_RANK"]) - ddp_world_size = int(os.environ["WORLD_SIZE"]) - DEVICE = f"cuda:{ddp_local_rank}" - torch.cuda.set_device(DEVICE) - - -# torchrun --nproc_per_node 2 1-pretrain.py if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Pretraining") - parser.add_argument("--out_dir", 
type=str, default="../out") - # 若要以最快速度实现zero则epochs设置为1轮;否则应当利用有限的数据训练2~6个epochs。 - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--learning_rate", type=float, default=5e-4) - parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") - parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", action="store_true") - parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") - parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--ddp", action="store_true") - parser.add_argument("--accumulation_steps", type=int, default=8) - parser.add_argument("--grad_clip", type=float, default=1.0) - parser.add_argument("--warmup_iters", type=int, default=0) - parser.add_argument("--log_interval", type=int, default=100) - parser.add_argument("--save_interval", type=int, default=100) - parser.add_argument('--local_rank', type=int, default=-1) - parser.add_argument('--hidden_size', default=512, type=int) - parser.add_argument('--num_hidden_layers', default=8, type=int) - parser.add_argument('--max_seq_len', default=512, type=int) - parser.add_argument('--use_moe', default=False, type=bool) - parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl") + parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") + parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名") + parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)") + parser.add_argument("--batch_size", type=int, default=32, help="batch size") + parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率") + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") + parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") + parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") + parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数") + parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") + parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") + parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔") + parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") + parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度") + parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE") + parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径") + parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始") + parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是") + parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") + parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名") args = parser.parse_args() - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, - use_moe=args.use_moe) - args.save_dir = os.path.join(args.out_dir) + # ========== 1. 
+    # ========== 1. Init environment & random seed ==========
+    local_rank = init_distributed_mode()
+    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+    # ========== 2. Set up dirs, model config, look for a checkpoint ==========
     os.makedirs(args.save_dir, exist_ok=True)
-    os.makedirs(args.out_dir, exist_ok=True)
-    tokens_per_iter = args.batch_size * args.max_seq_len
+    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+    # ========== 3. Set up mixed precision ==========
     device_type = "cuda" if "cuda" in args.device else "cpu"
-
-    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
-    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
-
-    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-    ddp_local_rank, DEVICE = 0, "cuda:0"
-
-    base_seed = 1337
-    torch.manual_seed(base_seed)
-    torch.cuda.manual_seed(base_seed)
-
-    if ddp:
-        init_distributed_mode()
-        args.device = torch.device(DEVICE)
-        rank = dist.get_rank()
-        torch.manual_seed(base_seed + rank)
-        # also seed the CUDA RNG
-        torch.cuda.manual_seed(base_seed + rank)
-
-    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+    # ========== 4. Set up wandb ==========
+    wandb = None
+    if args.use_wandb and is_main_process():
         import swanlab as wandb
-
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
-    else:
-        wandb = None
-
-    model, tokenizer = init_model(lm_config)
+        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+        resume = 'must' if wandb_id else None
+        wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+    # ========== 5. Build model, data, optimizer ==========
+    model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
     train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=args.batch_size,
-        pin_memory=True,
-        drop_last=False,
-        shuffle=(train_sampler is None),
-        num_workers=args.num_workers,
-        sampler=train_sampler
-    )
-
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    if ddp:
+
+    # ========== 6. Restore state from checkpoint ==========
+    start_epoch, start_step = 0, 0
+    if ckp_data:
+        model.load_state_dict(ckp_data['model'])
+        optimizer.load_state_dict(ckp_data['optimizer'])
+        scaler.load_state_dict(ckp_data['scaler'])
+        start_epoch = ckp_data['epoch']
+        start_step = ckp_data.get('step', 0)
+
+    # ========== 7. Wrap model with DDP ==========
+    if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
-    iter_per_epoch = len(train_loader)
-    for epoch in range(args.epochs):
+        model = DistributedDataParallel(model, device_ids=[local_rank])
+
+    # ========== 8. Train ==========
+    for epoch in range(start_epoch, args.epochs):
         train_sampler and train_sampler.set_epoch(epoch)
-        train_epoch(epoch, wandb)
+        if epoch == start_epoch and start_step > 0:  # first epoch after a resume
+            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step)
+            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+            train_epoch(epoch, loader, len(loader) + start_step, start_step, wandb)
+        else:  # default: start from the beginning
+            loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+            train_epoch(epoch, loader, len(loader), 0, wandb)
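Note on the resume path above: on a restart, SkipBatchSampler (added in trainer/trainer_utils.py at the end of this patch) fast-forwards the sampler past the batches already consumed, and enumerate(loader, start=start_step + 1) keeps the step numbering contiguous. This reproduces the interrupted run only if the shuffle order is identical, which is why setup_seed() seeds every rank deterministically. A minimal sketch of the arithmetic, using a hypothetical toy dataset that is not part of this patch:

    # Toy resume check: 10 samples, batch_size=2, checkpoint written after step 3.
    from torch.utils.data import DataLoader
    from trainer.trainer_utils import SkipBatchSampler

    dataset = list(range(10))
    start_step = 3
    batch_sampler = SkipBatchSampler(range(len(dataset)), batch_size=2, skip_batches=start_step)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)
    for step, batch in enumerate(loader, start=start_step + 1):
        print(step, batch)   # step 4 -> tensor([6, 7]), step 5 -> tensor([8, 9])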
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index e13e741..74dc72c 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -3,31 +3,35 @@ import sys
 __package__ = "trainer"
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 import argparse
-import time
 import re
 import gc
+import warnings
 import torch
-from contextlib import nullcontext
 import torch.distributed as dist
+from transformers import AutoTokenizer
+from contextlib import nullcontext
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from transformers import AutoModel
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
 from dataset.lm_dataset import RLAIFDataset
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from collections import defaultdict
+from trainer.trainer_utils import *
+
+warnings.filterwarnings('ignore')
 
 
 class AutoAdaptiveValueTracker:
+    """Adaptive value tracker for SPO"""
     def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
         self.rho_mode = rho_mode
         self.rho_const = rho_const
         self.D_half = D_half
         self.clip_lower = clip_lower
         self.clip_upper = clip_upper
-        # Stable initialization following N_init = 1/(1-clip_lower)
         N_init = 1.0 / (1.0 - self.clip_lower)
         self.alpha = 0.5 * N_init
         self.beta = 0.5 * N_init
@@ -62,43 +66,28 @@
         return rho
 
 
-def Logger(content):
-    if not ddp or dist.get_rank() == 0:
-        print(content)
-
-
 def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
     """Aggregate all reward functions into a total reward"""
-
     def reasoning_model_reward(rewards):
-        # 1. Format reward (only used when training a reasoning model)
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
-
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
-            if match_pattern:
-                format_rewards.append(0.5)
-            elif match_pattern2:
+            if match_pattern or match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        rewards += torch.tensor(format_rewards, device=args.device)
 
-        # 2. Tag-count reward (keeps the strict reward from being too sparse; reasoning model only)
        def mark_num(text):
            reward = 0
-            if text.count("<think>") == 1:
-                reward += 0.25
-            if text.count("</think>") == 1:
-                reward += 0.25
-            if text.count("<answer>") == 1:
-                reward += 0.25
-            if text.count("</answer>") == 1:
-                reward += 0.25
+            if text.count("<think>") == 1: reward += 0.25
+            if text.count("</think>") == 1: reward += 0.25
+            if text.count("<answer>") == 1: reward += 0.25
+            if text.count("</answer>") == 1: reward += 0.25
            return reward
 
        mark_rewards = [mark_num(response) for response in responses]
@@ -106,12 +95,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
         return rewards
 
     rewards = torch.zeros(len(responses), device=args.device)
-
-    # 3. Format reward
     if args.reasoning == 1:
-        rewards = reasoning_model_reward(rewards)  # used when training the reasoning model
+        rewards = reasoning_model_reward(rewards)
 
-    # 4. Score with the reward model
     with torch.no_grad():
        reward_model_scores = []
        scale = 3.0
@@ -142,8 +128,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
     return rewards
 
 
-def spo_train_epoch(epoch, wandb, value_tracker):
-    for step, batch in enumerate(train_loader):
+def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
+    for step, batch in enumerate(loader, start=start_step + 1):
         prompts = batch['prompt']  # list[str], length B
         prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                   padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
@@ -152,7 +138,9 @@
         prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
 
         with torch.no_grad():
-            outputs = (model.module if ddp else model).generate(
+            # a DDP-wrapped model exposes generate() via .module
+            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
+            outputs = model_for_gen.generate(
                 **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                 num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)  # [B, P+R]
@@ -205,42 +193,38 @@
             scheduler.step()
             optimizer.zero_grad()
 
-        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+        if step % args.log_interval == 0 or step == iters:
             policy_loss_val = loss.item()
             avg_reward_val = rewards.mean().item()
             avg_len_val = completion_mask.sum(dim=1).float().mean().item()
-            # average kl over valid tokens for logging
             kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
             avg_baseline_val = baselines.mean().item()
             current_lr = optimizer.param_groups[0]['lr']
-            Logger(
-                f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
-                f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
-                f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
+            Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
+                   f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
+                   f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
+                   f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
 
-            if wandb and (not ddp or dist.get_rank() == 0):
-                log_dict = {
+            if wandb and is_main_process():
+                wandb.log({
                     "policy_loss": policy_loss_val,
                     "reward": avg_reward_val,
                     "kl": kl_val,
                     "rho": float(rho),
                     "baseline": avg_baseline_val,
-                    # "avg_response_len": avg_len_val,
                     "advantages_mean": advantages.mean().item(),
                     "learning_rate": current_lr
-                }
-                wandb.log(log_dict)
+                })
 
-        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+        if (step % args.save_interval == 0 or step == iters) and is_main_process():
             model.eval()
-            moe_path = '_moe' if lm_config.use_moe else ''
-            suffix = 'spo'
-            ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'
-
-            state_dict = model.module.state_dict() if isinstance(model,
-                                                                 torch.nn.parallel.DistributedDataParallel) else model.state_dict()
+            moe_suffix = '_moe' if lm_config.use_moe else ''
+            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
             torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
+            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
+                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
             model.train()
 
         del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
@@ -249,120 +233,116 @@
         gc.collect()
 
 
-def init_model(lm_config):
-    tokenizer = AutoTokenizer.from_pretrained('../model/')
-    model = MiniMindForCausalLM(lm_config)
-    moe_path = '_moe' if lm_config.use_moe else ''
-    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
-    if args.reasoning == 1:
-        ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
-    state_dict = torch.load(ckp, map_location=args.device)
-    model.load_state_dict(state_dict, strict=False)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
+    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved weights")
+    parser.add_argument('--save_weight', default='spo', type=str, help="filename prefix for saved weights")
+    parser.add_argument("--epochs", type=int, default=1, help="number of epochs")
+    parser.add_argument("--batch_size", type=int, default=2, help="batch size")
+    parser.add_argument("--learning_rate", type=float, default=1e-7, help="initial learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
+    parser.add_argument("--num_workers", type=int, default=1, help="number of dataloader workers")
+    parser.add_argument("--accumulation_steps", type=int, default=4, help="gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
+    parser.add_argument("--log_interval", type=int, default=1, help="logging interval (steps)")
+    parser.add_argument("--save_interval", type=int, default=10, help="checkpoint interval (steps)")
+    parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
+    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
+    parser.add_argument('--use_moe', default=False, type=bool, help="whether to use MoE")
+    parser.add_argument('--max_seq_len', default=66, type=int, help="maximum prompt length")
+    parser.add_argument("--max_gen_len", type=int, default=1536, help="maximum generation length")
+    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="path to RLAIF data")
+    parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
+    parser.add_argument("--reasoning", type=int, default=1, help='0: plain model, 1: reasoning model')
+    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="path to the reward model")
+    parser.add_argument('--from_resume', default=0, type=int, help="auto-detect checkpoint and resume (0: no, 1: yes)")
+    parser.add_argument("--use_wandb", action="store_true", help="enable wandb logging")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb project name")
+    args = parser.parse_args()
+
+    # ========== 1. Init environment & random seed ==========
+    local_rank = init_distributed_mode()
+    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+    # ========== 2. Set up dirs, model config, look for a checkpoint ==========
+    os.makedirs(args.save_dir, exist_ok=True)
+    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
+                               max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
+    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+    # ========== 3. Set up mixed precision ==========
+    device_type = "cuda" if "cuda" in args.device else "cpu"
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+    # ========== 4. Set up wandb ==========
+    wandb = None
+    if args.use_wandb and is_main_process():
+        import swanlab as wandb
+        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+        resume = 'must' if wandb_id else None
+        wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+    # ========== 5. Init models (Policy, Ref, Reward), value tracker, data ==========
+    tokenizer = AutoTokenizer.from_pretrained('../model/')
+    moe_suffix = '_moe' if lm_config.use_moe else ''
+    base_weight = "reason" if args.reasoning == 1 else "full_sft"
+    ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+    state_dict = torch.load(ckp, map_location=args.device)
+    # Policy model
+    model = MiniMindForCausalLM(lm_config)
+    model.load_state_dict(state_dict, strict=False)
+    model = model.to(args.device)
+    Logger(f'Policy model trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
+    # Reference model
     ref_model = MiniMindForCausalLM(lm_config)
     ref_model.load_state_dict(state_dict, strict=False)
     ref_model.eval().requires_grad_(False)
-
-    Logger(f'Total trainable LLM params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
-    model = model.to(args.device)
     ref_model = ref_model.to(args.device)
-
-    reward_name = "../../internlm2-1_8b-reward"
+    # Reward model
     reward_model = AutoModel.from_pretrained(
-        reward_name,
-        device_map="cuda",
-        torch_dtype=torch.float16,
-        trust_remote_code=True,
+        args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
     ).to(args.device).eval().requires_grad_(False)
-    reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
-
-    return model, ref_model, tokenizer, reward_model, reward_tokenizer
-
-
-def init_distributed_mode():
-    if not ddp: return
-    global ddp_local_rank, DEVICE
-    dist.init_process_group(backend="nccl")
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    DEVICE = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(DEVICE)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--out_dir", type=str, default="../out")
-    parser.add_argument("--epochs", type=int, default=1)
-    parser.add_argument("--batch_size", type=int, default=2)
-    parser.add_argument("--learning_rate", type=float, default=1e-7)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO")
-    parser.add_argument("--num_workers", type=int, default=1)
-    parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=4)
-    parser.add_argument("--grad_clip", type=float, default=1.0)
-    parser.add_argument("--log_interval", type=int, default=1)
-    parser.add_argument("--save_interval", type=int, default=10)
-    parser.add_argument('--hidden_size', default=512, type=int)
-    parser.add_argument('--num_hidden_layers', default=8, type=int)
-    parser.add_argument('--use_moe', default=False, type=bool)
-    parser.add_argument('--max_seq_len', default=66, type=int)
-    parser.add_argument("--max_gen_len", type=int, default=1536)
-    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
-    parser.add_argument("--beta", type=float, default=0.02)
-    parser.add_argument("--reasoning", type=int, default=1, help='0: plain model, 1: reasoning model')
-    args = parser.parse_args()
-
-    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
-                               max_seq_len=args.max_seq_len + args.max_gen_len,
-                               use_moe=args.use_moe)
-    args.save_dir = os.path.join(args.out_dir)
-    os.makedirs(args.save_dir, exist_ok=True)
-    os.makedirs(args.out_dir, exist_ok=True)
-
-    ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
-    ddp = int(os.environ.get("RANK", -1)) != -1
-    ddp_local_rank, DEVICE = 0, "cuda:0"
-
-    base_seed = 1337
-    torch.manual_seed(base_seed)
-    torch.cuda.manual_seed(base_seed)
-
-    if ddp:
-        init_distributed_mode()
-        args.device = torch.device(DEVICE)
-        rank = dist.get_rank()
-        torch.manual_seed(base_seed + rank)
-        # also seed the CUDA RNG
-        torch.cuda.manual_seed(base_seed + rank)
-
-    if args.use_wandb and (not ddp or ddp_local_rank == 0):
-        import swanlab as wandb
-
-        wandb.init(project=args.wandb_project)
-    else:
-        wandb = None
-
-    model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
+    reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
+    # Value tracker
+    value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
+
     train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
+    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
+
+    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
+    iters = len(loader_for_count)
+    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
+    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
+
+    # ========== 6. Restore state from checkpoint ==========
+    start_epoch, start_step = 0, 0
+    if ckp_data:
+        model.load_state_dict(ckp_data['model'])
+        optimizer.load_state_dict(ckp_data['optimizer'])
+        scheduler.load_state_dict(ckp_data['scheduler'])
+        start_epoch = ckp_data['epoch']
+        start_step = ckp_data.get('step', 0)
+
+    # ========== 7. Wrap model with DDP ==========
+    if dist.is_initialized():
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
+        model = DistributedDataParallel(model, device_ids=[local_rank])
+
+    # ========== 8. Train ==========
+    for epoch in range(start_epoch, args.epochs):
+        train_sampler and train_sampler.set_epoch(epoch)
+        if epoch == start_epoch and start_step > 0:  # first epoch after a resume
+            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step)
+            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+            spo_train_epoch(epoch, loader, len(loader) + start_step, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
+        else:  # default: start from the beginning
+            loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
                                 drop_last=False, shuffle=(train_sampler is None),
                                 num_workers=args.num_workers, sampler=train_sampler)
-
-    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    iter_per_epoch = len(train_loader)
-    total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
-    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
-
-    if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
-    value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
-
-    for epoch in range(args.epochs):
-        train_sampler and train_sampler.set_epoch(epoch)
-        spo_train_epoch(epoch, wandb, value_tracker)
+            spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
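Note on the format reward in calculate_rewards() above: pattern accepts a <think>/<answer> response with the two blocks adjacent, pattern2 accepts the variant with a blank line between them, and mark_num() adds 0.25 for each tag that occurs exactly once. A quick sanity check of the two regexes (illustrative strings, not taken from the dataset):

    import re

    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
    pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"

    resp = "<think>\nreason step by step\n</think>\n<answer>\n42\n</answer>"
    assert re.match(pattern, resp, re.S)     # compact layout earns +0.5
    resp2 = "<think>\nreason step by step\n</think>\n\n<answer>\n42\n</answer>"
    assert re.match(pattern2, resp2, re.S)   # blank-line layout earns +0.5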
diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py
new file mode 100644
index 0000000..9675c14
--- /dev/null
+++ b/trainer/trainer_utils.py
@@ -0,0 +1,139 @@
+"""
+Shared training utility functions
+"""
+import os
+import random
+import math
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+
+
+def is_main_process():
+    return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def Logger(content):
+    if is_main_process():
+        print(content)
+
+
+def get_lr(current_step, total_steps, lr):
+    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
+
+
+def init_distributed_mode():
+    if int(os.environ.get("RANK", -1)) == -1:
+        return 0  # not a DDP run
+
+    dist.init_process_group(backend="nccl")
+    local_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local_rank)
+    return local_rank
+
+
+def setup_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
+    os.makedirs(save_dir, exist_ok=True)
+    moe_path = '_moe' if lm_config.use_moe else ''
+    ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
+    resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
+
+    if model is not None:
+        from torch.nn.parallel import DistributedDataParallel
+        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
+        ckp_tmp = ckp_path + '.tmp'
+        torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
+        os.replace(ckp_tmp, ckp_path)
+        wandb_id = None
+        if wandb:
+            if hasattr(wandb, 'get_run'):
+                run = wandb.get_run()
+                wandb_id = getattr(run, 'id', None) if run else None
+            else:
+                wandb_id = getattr(wandb, 'id', None)
+
+        resume_data = {
+            'model': state_dict,
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch,
+            'step': step,
+            'world_size': dist.get_world_size() if dist.is_initialized() else 1,
+            'wandb_id': wandb_id
+        }
+        for key, value in kwargs.items():
+            if value is not None:
+                if hasattr(value, 'state_dict'):
+                    if isinstance(value, DistributedDataParallel):
+                        resume_data[key] = value.module.state_dict()
+                    else:
+                        resume_data[key] = value.state_dict()
+                else:
+                    resume_data[key] = value
+
+        resume_tmp = resume_path + '.tmp'
+        torch.save(resume_data, resume_tmp)
+        os.replace(resume_tmp, resume_path)
+    else:  # load mode
+        if os.path.exists(resume_path):
+            ckp_data = torch.load(resume_path, map_location='cpu')
+            saved_ws = ckp_data.get('world_size', 1)
+            current_ws = dist.get_world_size() if dist.is_initialized() else 1
+            if saved_ws != current_ws:
+                ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
+                Logger(f'World size changed ({saved_ws}→{current_ws}); resume step rescaled to {ckp_data["step"]}')
+            return ckp_data
+        return None
+
+
+def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
+    from transformers import AutoTokenizer
+    from model.model_minimind import MiniMindForCausalLM
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+    model = MiniMindForCausalLM(lm_config)
+
+    if from_weight != 'none':
+        moe_suffix = '_moe' if lm_config.use_moe else ''
+        weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+        weights = torch.load(weight_path, map_location=device)
+        model.load_state_dict(weights, strict=False)
+
+    Logger(f'Loaded model trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
+    return model.to(device), tokenizer
+
+
+class SkipBatchSampler(Sampler):
+    def __init__(self, sampler, batch_size, skip_batches=0):
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.skip_batches = skip_batches
+
+    def __iter__(self):
+        batch = []
+        skipped = 0
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                if skipped < self.skip_batches:
+                    skipped += 1
+                    batch = []
+                    continue
+                yield batch
+                batch = []
+        if len(batch) > 0 and skipped >= self.skip_batches:
+            yield batch
+
+    def __len__(self):
+        total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
+        return max(0, total_batches - self.skip_batches)
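Two properties of trainer_utils.py worth noting. First, lm_checkpoint() writes both the half-precision weights and the resume state to a .tmp file and swaps it in with os.replace(), so a crash mid-save cannot corrupt the previous checkpoint; on load it also rescales the saved step when the world size changes, since with DistributedSampler each global step consumes batch_size * world_size samples. Second, get_lr() is a cosine schedule that peaks at 1.1 x lr and floors at 0.1 x lr. A minimal sketch of both (hypothetical numbers, not part of the patch):

    import math

    def get_lr(current_step, total_steps, lr):
        return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))

    assert abs(get_lr(0, 100, 1.0) - 1.1) < 1e-9     # peak at step 0
    assert abs(get_lr(100, 100, 1.0) - 0.1) < 1e-9   # floor at the final step

    # Resume-step rescaling: a run checkpointed on 4 GPUs at step 1000 has consumed
    # 1000 * 4 rank-batches, which corresponds to step 2000 when resuming on 2 GPUs.
    saved_ws, saved_step, current_ws = 4, 1000, 2
    assert saved_step * saved_ws // current_ws == 2000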