From ef40a1f271b6295c35abbd463cc2865cce47d85f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BF=9F=E9=94=A6=E6=B4=8B?= <129938610+UCAS-zhaijinyang@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:06:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E4=BA=9B=E5=B0=8F=E7=9A=84=E6=94=B9?= =?UTF-8?q?=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1、给模型定义和模型训练的相关程序添加注释,并移动代码位置,增加代码的可读性和可理解性 2、修改了训练因故中断后,无法正常从分布式环境中续训的问题 --- model/model_minimind.py | 70 +++++++++++++++------- scripts/convert_model.py | 6 +- trainer/train_agent.py | 59 +++++++++++++------ trainer/train_distillation.py | 2 +- trainer/train_dpo.py | 70 ++++++++++++---------- trainer/train_full_sft.py | 106 +++++++++++++++++----------------- trainer/train_grpo.py | 2 +- trainer/train_lora.py | 72 ++++++++++++----------- trainer/train_ppo.py | 104 ++++++++++++++++++--------------- trainer/train_pretrain.py | 98 +++++++++++++++---------------- trainer/train_tokenizer.py | 53 ++++++++++++++--- 11 files changed, 376 insertions(+), 266 deletions(-) diff --git a/model/model_minimind.py b/model/model_minimind.py index 70ee32b..4dba819 100755 --- a/model/model_minimind.py +++ b/model/model_minimind.py @@ -11,21 +11,17 @@ class MiniMindConfig(PretrainedConfig): model_type = "minimind" def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs): super().__init__(**kwargs) - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.use_moe = use_moe - self.dropout = kwargs.get("dropout", 0.0) + #################################################### + # token相关 + #################################################### self.vocab_size = kwargs.get("vocab_size", 6400) self.bos_token_id = kwargs.get("bos_token_id", 1) self.eos_token_id = kwargs.get("eos_token_id", 2) - self.flash_attn = kwargs.get("flash_attn", True) - self.num_attention_heads = kwargs.get("num_attention_heads", 8) - self.num_key_value_heads = kwargs.get("num_key_value_heads", 4) - self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads) - self.hidden_act = kwargs.get("hidden_act", 'silu') - self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64) + + #################################################### + # embedding相关 + #################################################### self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768) - self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6) self.rope_theta = kwargs.get("rope_theta", 1e6) self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False) self.rope_scaling = { @@ -36,6 +32,35 @@ class MiniMindConfig(PretrainedConfig): "attention_factor": 1.0, "type": "yarn" } if self.inference_rope_scaling else None + + #################################################### + # 表示空间(Representation Space)相关 + #################################################### + self.hidden_size = hidden_size + + #################################################### + # transformer相关 + #################################################### + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = kwargs.get("num_attention_heads", 8) + # GQA中的KV复用机制 + self.num_key_value_heads = kwargs.get("num_key_value_heads", 4) + self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads) + self.flash_attn = kwargs.get("flash_attn", True) + + #################################################### + # 前馈网络相关 + #################################################### + self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64) + self.hidden_act = kwargs.get("hidden_act", 'silu') + + #################################################### + # 模型整体架构相关 + #################################################### + self.use_moe = use_moe + self.dropout = kwargs.get("dropout", 0.0) + self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6) + ### MoE specific configs (ignored if use_moe = False) self.num_experts = kwargs.get("num_experts", 4) self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1) @@ -46,17 +71,6 @@ class MiniMindConfig(PretrainedConfig): # 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏 # MiniMind Model # 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏 -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - return (self.weight * self.norm(x.float())).type_as(x) def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None): freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0 @@ -87,6 +101,18 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: if n_rep == 1: return x return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)) +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return (self.weight * self.norm(x.float())).type_as(x) + class Attention(nn.Module): def __init__(self, config: MiniMindConfig): super().__init__() diff --git a/scripts/convert_model.py b/scripts/convert_model.py index 34099e5..e6e7fcb 100644 --- a/scripts/convert_model.py +++ b/scripts/convert_model.py @@ -126,10 +126,12 @@ def convert_json_to_jinja(json_file_path, output_path): if __name__ == '__main__': - lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=True) + + # 注意这里use_moe参数的配置,默认使用非MoE模型 + lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False) # convert torch to transformers torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth" - transformers_path = '../minimind-3-moe' + transformers_path = '../minimind-3' convert_torch2transformers(torch_path, transformers_path) # # merge lora diff --git a/trainer/train_agent.py b/trainer/train_agent.py index 9a84b69..82ee985 100644 --- a/trainer/train_agent.py +++ b/trainer/train_agent.py @@ -241,15 +241,19 @@ def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, rewa def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False): last_step = start_step for step, batch in enumerate(loader, start=start_step + 1): + ########################### 训练前操作 ########################### + # 数据准备 messages_batch = batch['messages'] tools_batch = batch['tools'] gt_batch = batch['gt'] last_step = step + prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)] + with torch.no_grad(): completions, contexts, prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch, turn_outputs_batch, unfinished_batch = rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, args.num_generations, max_turns=3, max_new_tokens=args.max_gen_len, thinking_ratio=args.thinking_ratio, device=args.device) - prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)] + # 数据处理(保证序列长度一致) packed_samples = [] for p, r, m, old_lp in zip(prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch): ids = p + r @@ -268,6 +272,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32) old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32) + ########################### 训练中操作 ########################### + # 数据计算 model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: res = model_unwrapped(input_ids) @@ -316,6 +322,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model kl_div = ref_per_token_logps - per_token_logps per_token_kl = torch.exp(kl_div) - kl_div - 1 ratio = torch.exp(per_token_logps - old_per_token_logps) + + # 定义损失函数 if args.loss_type == "cispo": clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach() per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl) @@ -327,13 +335,18 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model policy_loss = (((per_token_loss * completion_mask).sum(dim=1)[valid_rows] / token_counts[valid_rows].clamp(min=1)).mean() if valid_rows.any() else per_token_loss.sum() * 0.0) loss = (policy_loss + aux_loss) / args.accumulation_steps + + # 反向传播 loss.backward() + # 梯度更新 if step % args.accumulation_steps == 0: if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step(); scheduler.step(); optimizer.zero_grad() if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model) + ########################### 训练后操作 ########################### + # 日志打印 if step % args.log_interval == 0 or step == iters: pl = loss.item() * args.accumulation_steps ar = rewards.mean().item() @@ -346,6 +359,7 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model if wandb and is_main_process(): wandb.log({"reward":ar,"kl_ref":kl,"group_reward_std":gs,"advantages_std":ast,"policy_loss":pl,"avg_response_len":al,"advantages_mean":am,"learning_rate":lr}) + # 模型保存 if (step % args.save_interval == 0 or step == iters) and is_main_process(): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' @@ -409,25 +423,20 @@ if __name__ == "__main__": parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径") args = parser.parse_args() + # ========== 1. 创建训练环境 ========== local_rank = init_distributed_mode() if dist.is_initialized(): args.device = f"cuda:{local_rank}" setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + # ========== 2. 模型相关 ========== os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, - max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None - + device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) - wandb = None - if args.use_wandb and is_main_process(): - import swanlab as wandb - wandb_id = ckp_data.get('wandb_id') if ckp_data else None - resume = 'must' if wandb_id else None - wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume) + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, + max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe)) model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) @@ -456,14 +465,6 @@ if __name__ == "__main__": total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) - start_epoch, start_step = 0, 0 - if ckp_data: - model.load_state_dict(ckp_data['model']) - optimizer.load_state_dict(ckp_data['optimizer']) - scheduler.load_state_dict(ckp_data['scheduler']) - start_epoch = ckp_data['epoch'] - start_step = ckp_data.get('step', 0) - if args.use_compile == 1: model = torch.compile(model) Logger('torch.compile enabled') @@ -472,6 +473,25 @@ if __name__ == "__main__": model = DistributedDataParallel(model, device_ids=[local_rank]) if is_main_process(): rollout_engine.update_policy(model) + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None + + # ========== 3. checkpoint相关 ========== + wandb = None + if args.use_wandb and is_main_process(): + import swanlab as wandb + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume) + + start_epoch, start_step = 0, 0 + if ckp_data: + model.module.load_state_dict(ckp_data['model']) + optimizer.load_state_dict(ckp_data['optimizer']) + scheduler.load_state_dict(ckp_data['scheduler']) + start_epoch = ckp_data['epoch'] + start_step = ckp_data.get('step', 0) + + # ========== 4. 开始训练 ========== for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() @@ -484,4 +504,5 @@ if __name__ == "__main__": else: rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang")) + # ========== 5. 清理分布进程 ========== if dist.is_initialized(): dist.destroy_process_group() diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index 6b1b5b9..59c1697 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -215,7 +215,7 @@ if __name__ == "__main__": # ========== 6. 从ckp恢复状态 ========== start_epoch, start_step = 0, 0 if ckp_data: - model.load_state_dict(ckp_data['model']) + model.module.load_state_dict(ckp_data['model']) optimizer.load_state_dict(ckp_data['optimizer']) scaler.load_state_dict(ckp_data['scaler']) start_epoch = ckp_data['epoch'] diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 1876e26..ca3ba15 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -54,6 +54,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= last_step = start_step for step, batch in enumerate(loader, start=start_step + 1): + ########################### 训练前操作 ########################### + # 数据加载 last_step = step x_chosen = batch['x_chosen'].to(args.device) x_rejected = batch['x_rejected'].to(args.device) @@ -64,11 +66,14 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= x = torch.cat([x_chosen, x_rejected], dim=0) y = torch.cat([y_chosen, y_rejected], dim=0) mask = torch.cat([mask_chosen, mask_rejected], dim=0) - lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) + + # 学习率调整 for param_group in optimizer.param_groups: param_group['lr'] = lr + ########################### 训练中操作 ########################### + # 模型前向传播 with autocast_ctx: with torch.no_grad(): ref_outputs = ref_model(x) @@ -83,8 +88,10 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= loss = dpo_loss_val + outputs.aux_loss loss = loss / args.accumulation_steps + # 模型反向传播 scaler.scale(loss).backward() + # 梯度更新 if step % args.accumulation_steps == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) @@ -92,6 +99,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= scaler.update() optimizer.zero_grad(set_to_none=True) + ########################### 训练后操作 ########################### + # 日志打印 if step % args.log_interval == 0 or step == iters: spend_time = time.time() - start_time current_loss = loss.item() * args.accumulation_steps @@ -104,6 +113,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min}) + # 模型保存 if (step % args.save_interval == 0 or step == iters) and is_main_process(): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' @@ -154,31 +164,20 @@ if __name__ == "__main__": parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") args = parser.parse_args() - # ========== 1. 初始化环境和随机种子 ========== + # ========== 1. 创建训练环境 ========== local_rank = init_distributed_mode() if dist.is_initialized(): args.device = f"cuda:{local_rank}" setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) - # ========== 2. 配置目录、模型参数、检查ckp ========== + # ========== 2. 模型相关 ========== os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None - - # ========== 3. 设置混合精度 ========== + device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) - - # ========== 4. 配wandb ========== - wandb = None - if args.use_wandb and is_main_process(): - import swanlab as wandb - wandb_id = ckp_data.get('wandb_id') if ckp_data else None - resume = 'must' if wandb_id else None - wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" - wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) - - # ========== 5. 定义模型和参考模型 ========== + + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) + model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M') # 初始化参考模型(ref_model冻结) @@ -191,25 +190,34 @@ if __name__ == "__main__": train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) - - # ========== 6. 从ckp恢复状态 ========== - start_epoch, start_step = 0, 0 - if ckp_data: - model.load_state_dict(ckp_data['model']) - optimizer.load_state_dict(ckp_data['optimizer']) - scaler.load_state_dict(ckp_data['scaler']) - start_epoch = ckp_data['epoch'] - start_step = ckp_data.get('step', 0) - - # ========== 7. 编译和分布式包装 ========== + if args.use_compile == 1: model = torch.compile(model) Logger('torch.compile enabled') if dist.is_initialized(): model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) + + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None - # ========== 8. 开始训练 ========== + # ========== 3.checkpoint相关 ========== + wandb = None + if args.use_wandb and is_main_process(): + import swanlab as wandb + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) + + start_epoch, start_step = 0, 0 + if ckp_data: + model.module.load_state_dict(ckp_data['model']) + optimizer.load_state_dict(ckp_data['optimizer']) + scaler.load_state_dict(ckp_data['scaler']) + start_epoch = ckp_data['epoch'] + start_step = ckp_data.get('step', 0) + + # ========== 4. 训练 ========== for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() @@ -222,5 +230,5 @@ if __name__ == "__main__": else: train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta) - # ========== 9. 清理分布进程 ========== + # ========== 5. 清理分布进程 ========== if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index f685760..d656db0 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -19,26 +19,32 @@ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint warnings.filterwarnings('ignore') - def train_epoch(epoch, loader, iters, start_step=0, wandb=None): start_time = time.time() - last_step = start_step for step, (input_ids, labels) in enumerate(loader, start=start_step + 1): + + ########################### 训练前操作 ########################### + # 数据加载 input_ids = input_ids.to(args.device) labels = labels.to(args.device) - last_step = step lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) + + # 学习率调整 for param_group in optimizer.param_groups: param_group['lr'] = lr + ########################### 训练中操作 ########################### + # 模型前向传播 with autocast_ctx: res = model(input_ids, labels=labels) loss = res.loss + res.aux_loss loss = loss / args.accumulation_steps + # 模型反向传播 scaler.scale(loss).backward() - if step % args.accumulation_steps == 0: + # 梯度更新 + if (step + 1) % args.accumulation_steps == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) @@ -47,17 +53,20 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): optimizer.zero_grad(set_to_none=True) - if step % args.log_interval == 0 or step == iters: + ########################### 训练后操作 ########################### + # 日志打印 + if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time current_loss = loss.item() * args.accumulation_steps current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0 current_logits_loss = current_loss - current_aux_loss current_lr = optimizer.param_groups[-1]['lr'] - eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60 + eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min') if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min}) - if (step % args.save_interval == 0 or step == iters) and is_main_process(): + # 模型保存 + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' @@ -65,20 +74,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): raw_model = getattr(raw_model, '_orig_mod', raw_model) state_dict = raw_model.state_dict() torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp) - lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, - epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler) + lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') model.train() del state_dict del input_ids, labels, res, loss - if last_step > start_step and last_step % args.accumulation_steps != 0: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Full SFT") @@ -86,7 +87,7 @@ if __name__ == "__main__": parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名") parser.add_argument("--epochs", type=int, default=2, help="训练轮数") parser.add_argument("--batch_size", type=int, default=16, help="batch size") - parser.add_argument("--learning_rate", type=float, default=1e-5, help="初始学习率") + parser.add_argument("--learning_rate", type=float, default=1e-6, help="初始学习率") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数") @@ -96,7 +97,7 @@ if __name__ == "__main__": parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度") parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") - parser.add_argument('--max_seq_len', default=768, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") + parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="训练数据路径") parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练,为none则不基于任何权重训练") @@ -106,55 +107,52 @@ if __name__ == "__main__": parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") args = parser.parse_args() - # ========== 1. 初始化环境和随机种子 ========== + # ========== 1. 创建训练环境 ========== local_rank = init_distributed_mode() - if dist.is_initialized(): args.device = f"cuda:{local_rank}" - setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + if dist.is_initialized(): args.device = f"cuda:{local_rank}" # 分布式训练时,设置设备 + setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) # 设置随机种子 - # ========== 2. 配置目录、模型参数、检查ckp ========== - os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None - - # ========== 3. 设置混合精度 ========== + # ========== 2. 模型相关 ========== + os.makedirs(args.save_dir, exist_ok=True) # 创建模型保存目录 + device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 - autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) - - # ========== 4. 配wandb ========== + autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # 设置混合精度 + + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) + model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) + if args.use_compile == 1: + model = torch.compile(model) + Logger('torch.compile enabled') + train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len) + train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None + scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # 所有模型相关的初始化 + + if dist.is_initialized(): + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} + model = DistributedDataParallel(model, device_ids=[local_rank]) # 分布式训练模型初始化 + + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None # 检查checkpoint + + # ========== 3. checkpoint相关 ========== wandb = None if args.use_wandb and is_main_process(): import swanlab as wandb wandb_id = ckp_data.get('wandb_id') if ckp_data else None resume = 'must' if wandb_id else None - wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" - wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) - - # ========== 5. 定义模型、数据、优化器 ========== - model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) - train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len) - train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None - scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) - optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) - - # ========== 6. 从ckp恢复状态 ========== + wandb_run_name = f"MiniMind-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # 通过checkpoint进行训练可视化 + start_epoch, start_step = 0, 0 if ckp_data: - model.load_state_dict(ckp_data['model']) + model.module.load_state_dict(ckp_data['model']) optimizer.load_state_dict(ckp_data['optimizer']) scaler.load_state_dict(ckp_data['scaler']) start_epoch = ckp_data['epoch'] - start_step = ckp_data.get('step', 0) + start_step = ckp_data.get('step', 0) # 通过checkpoint进行状态恢复 - # ========== 7. 编译和分布式包装 ========== - if args.use_compile == 1: - model = torch.compile(model) - Logger('torch.compile enabled') - if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - model = DistributedDataParallel(model, device_ids=[local_rank]) - - # ========== 8. 开始训练 ========== + # ========== 4. 训练 ========== for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() @@ -167,5 +165,5 @@ if __name__ == "__main__": else: train_epoch(epoch, loader, len(loader), 0, wandb) - # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + # ========== 5. 撤销训练环境 ========== + if dist.is_initialized(): dist.destroy_process_group() # 撤销分布式训练环境 diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 2d514e2..00c762a 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -299,7 +299,7 @@ if __name__ == "__main__": # ========== 6. 从ckp恢复状态 ========== start_epoch, start_step = 0, 0 if ckp_data: - model.load_state_dict(ckp_data['model']) + model.module.load_state_dict(ckp_data['model']) optimizer.load_state_dict(ckp_data['optimizer']) scheduler.load_state_dict(ckp_data['scheduler']) start_epoch = ckp_data['epoch'] diff --git a/trainer/train_lora.py b/trainer/train_lora.py index 4dae568..aabaf7e 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -25,20 +25,28 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None): start_time = time.time() last_step = start_step for step, (input_ids, labels) in enumerate(loader, start=start_step + 1): + ########################### 训练前操作 ########################### + # 数据加载 input_ids = input_ids.to(args.device) labels = labels.to(args.device) last_step = step + + # 学习率调整 lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr'] = lr + ########################### 训练中操作 ########################### + # 模型前向传播 with autocast_ctx: res = model(input_ids, labels=labels) loss = res.loss + res.aux_loss loss = loss / args.accumulation_steps + # 模型反向传播 scaler.scale(loss).backward() - + + # 梯度更新 if step % args.accumulation_steps == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip) @@ -46,6 +54,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None): scaler.update() optimizer.zero_grad(set_to_none=True) + ########################### 训练后操作 ########################### + # 日志打印 if step % args.log_interval == 0 or step == iters: spend_time = time.time() - start_time current_loss = loss.item() * args.accumulation_steps @@ -56,10 +66,10 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None): Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min') if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min}) + # 模型保存(仅保存LoRA权重) if (step % args.save_interval == 0 or step == iters) and is_main_process(): model.eval() lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth' - # LoRA只保存LoRA权重 save_lora(model, lora_save_path) lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') model.train() @@ -99,41 +109,27 @@ if __name__ == "__main__": parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") args = parser.parse_args() - # ========== 1. 初始化环境和随机种子 ========== + # ========== 1. 创建训练环境 ========== local_rank = init_distributed_mode() if dist.is_initialized(): args.device = f"cuda:{local_rank}" setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) - # ========== 2. 配置目录、模型参数、检查ckp ========== - os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None - - # ========== 3. 设置混合精度 ========== + # ========== 2. 模型相关 ========== + os.makedirs(args.save_dir, exist_ok=True) + device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) - - # ========== 4. 配wandb ========== - wandb = None - if args.use_wandb and is_main_process(): - import swanlab as wandb - wandb_id = ckp_data.get('wandb_id') if ckp_data else None - resume = 'must' if wandb_id else None - wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" - wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) - - # ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ========== + + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) apply_lora(model) - # 统计参数 total_params = sum(p.numel() for p in model.parameters()) lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M") Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M") Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%") - # 冻结非LoRA参数,收集LoRA参数 lora_params = [] for name, param in model.named_parameters(): @@ -143,13 +139,29 @@ if __name__ == "__main__": else: param.requires_grad = False - # ========== 6. 定义数据和优化器 ========== train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len) train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) optimizer = optim.AdamW(lora_params, lr=args.learning_rate) + + if args.use_compile == 1: + model = torch.compile(model) + Logger('torch.compile enabled') + if dist.is_initialized(): + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} + model = DistributedDataParallel(model, device_ids=[local_rank]) + + ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None + + # ========== 3. checkpoint相关 ========== + wandb = None + if args.use_wandb and is_main_process(): + import swanlab as wandb + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) - # ========== 7. 从ckp恢复状态 ========== start_epoch, start_step = 0, 0 if ckp_data: model.load_state_dict(ckp_data['model'], strict=False) @@ -158,15 +170,7 @@ if __name__ == "__main__": start_epoch = ckp_data['epoch'] start_step = ckp_data.get('step', 0) - # ========== 8. 编译和分布式包装 ========== - if args.use_compile == 1: - model = torch.compile(model) - Logger('torch.compile enabled') - if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - model = DistributedDataParallel(model, device_ids=[local_rank]) - - # ========== 9. 开始训练 ========== + # ========== 4. 训练 ========== for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() @@ -179,5 +183,5 @@ if __name__ == "__main__": else: train_epoch(epoch, loader, len(loader), lora_params, 0, wandb) - # ========== 10. 清理分布进程 ========== + # ========== 5. 撤销训练环境 ========== if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 27668ec..0469b3d 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -25,28 +25,24 @@ from trainer.rollout_engine import create_rollout_engine warnings.filterwarnings('ignore') - -def rep_penalty(text, n=3, cap=0.5): - toks = re.findall(r"\w+|[^\w\s]", text.lower()) - grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)] - return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0 - - # 自定义的Critic模型,继承自MiniMindLM class CriticModel(MiniMindForCausalLM): def __init__(self, params): super().__init__(params) - # 替换lm_head为输出单一价值的线性层 self.value_head = nn.Linear(params.hidden_size, 1) def forward(self, input_ids=None, attention_mask=None, **kwargs): # 使用基础模型获取隐藏状态 outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) hidden_states = self.model.norm(outputs[0]) - # 使用value_head获取价值估计 + # 使用value_head替代lm_head,获取价值估计 values = self.value_head(hidden_states).squeeze(-1) return values +def rep_penalty(text, n=3, cap=0.5): + toks = re.findall(r"\w+|[^\w\s]", text.lower()) + grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)] + return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0 def calculate_rewards(prompts, responses, reward_model): rewards = torch.zeros(len(responses), device=args.device) @@ -81,6 +77,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched grad_accum_step = 0 for step, batch in enumerate(loader, start=start_step + 1): + ########################### 训练前操作 ########################### + # 数据准备 prompts = batch["prompt"] # list[str], length B enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len, padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P] @@ -110,7 +108,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}") Logger(f"[DEBUG] reward={rewards[i].item():.4f}") Logger('='*100) - + + # 数据处理 full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R] labels = gen_out[:, 1:].clone() # [B, P+R-1] seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1 @@ -126,6 +125,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float() resp_value_mask = resp_policy_mask.clone() + ########################### 训练中操作 ########################### + # 数据计算,初始化优势函数advantages和价值函数returns with torch.no_grad(): # Rollout阶段只需推理获取old_logp和old_values,切断梯度省显存 critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask) @@ -155,7 +156,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1) adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1) advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask - + + # 训练参数初始化 mb_size = max(1, min(args.mini_batch_size, B)) stop_ppo = False policy_loss_sum = 0.0 @@ -167,6 +169,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched log_count = 0 actor_unwrapped = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model critic_unwrapped = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model + + # 强化学习训练迭代 for ppo_epoch in range(args.ppo_update_iters): if stop_ppo: break @@ -240,6 +244,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched actor_optimizer.zero_grad() critic_optimizer.zero_grad() + # 梯度更新 if grad_accum_step % args.accumulation_steps != 0: clip_grad_norm_(actor_model.parameters(), args.grad_clip) clip_grad_norm_(critic_model.parameters(), args.grad_clip) @@ -249,9 +254,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched critic_scheduler.step() actor_optimizer.zero_grad() critic_optimizer.zero_grad() - - if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model) + ########################### 训练后操作 ########################### + # 模型更新 + if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model) + + # 日志打印 if is_main_process(): critic_loss_val = value_loss_sum / max(log_count, 1) reward_val = rewards.mean().item() @@ -278,6 +286,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched f"ClipFrac: {clipfrac_val:.4f}, Critic Loss: {critic_loss_val:.4f}, " f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}") + # 模型保存 if (step % args.save_interval == 0 or step == iters) and is_main_process(): actor_model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' @@ -345,43 +354,37 @@ if __name__ == "__main__": parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径") args = parser.parse_args() - # ========== 1. 初始化环境和随机种子 ========== + # ========== 1. 创建训练环境 ========== local_rank = init_distributed_mode() if dist.is_initialized(): args.device = f"cuda:{local_rank}" setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) - # ========== 2. 配置目录、模型参数、检查ckp ========== + # ========== 2. 模型相关 ========== os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None - - # ========== 3. 设置混合精度 ========== + device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) - - # ========== 4. 配wandb ========== - wandb = None - if args.use_wandb and is_main_process(): - import swanlab as wandb - wandb_id = ckp_data.get('wandb_id') if ckp_data else None - resume = 'must' if wandb_id else None - wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}" - wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) - - # ========== 5. 初始化模型和数据 ========== + + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) + base_weight = args.from_weight - # Actor模型 + + # LLM_PPO四大模型初始化 actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device) + ref_model, _ = init_model(lm_config, base_weight, device=args.device) ref_model = ref_model.eval().requires_grad_(False) moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth' state_dict = torch.load(ckp, map_location=args.device) + critic_model = CriticModel(lm_config) critic_model.load_state_dict(state_dict, strict=False) critic_model = critic_model.to(args.device) + reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16) + # Rollout引擎 rollout_engine = create_rollout_engine( engine_type=args.rollout_engine, @@ -393,6 +396,7 @@ if __name__ == "__main__": sglang_model_path=args.sglang_model_path, sglang_shared_path=args.sglang_shared_path, ) + train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len), thinking_ratio=args.thinking_ratio) train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate) @@ -404,18 +408,6 @@ if __name__ == "__main__": actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10) - start_epoch, start_step = 0, 0 - if ckp_data: - actor_model.load_state_dict(ckp_data['model']) - critic_model.load_state_dict(ckp_data['critic_model']) - actor_optimizer.load_state_dict(ckp_data['optimizer']) - critic_optimizer.load_state_dict(ckp_data['critic_optimizer']) - actor_scheduler.load_state_dict(ckp_data['scheduler']) - critic_scheduler.load_state_dict(ckp_data['critic_scheduler']) - start_epoch = ckp_data['epoch'] - start_step = ckp_data.get('step', 0) - - # ========== 7. 编译和分布式包装 ========== if args.use_compile == 1: actor_model = torch.compile(actor_model) Logger('torch.compile enabled') @@ -426,8 +418,30 @@ if __name__ == "__main__": actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank]) critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank]) if is_main_process(): rollout_engine.update_policy(actor_model) + + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None + + # ========== 3. checkpoint相关 ========== + wandb = None + if args.use_wandb and is_main_process(): + import swanlab as wandb + wandb_id = ckp_data.get('wandb_id') if ckp_data else None + resume = 'must' if wandb_id else None + wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}" + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) + + start_epoch, start_step = 0, 0 + if ckp_data: + actor_model.module.load_state_dict(ckp_data['model']) + critic_model.module.load_state_dict(ckp_data['critic_model']) + actor_optimizer.load_state_dict(ckp_data['optimizer']) + critic_optimizer.load_state_dict(ckp_data['critic_optimizer']) + actor_scheduler.load_state_dict(ckp_data['scheduler']) + critic_scheduler.load_state_dict(ckp_data['critic_scheduler']) + start_epoch = ckp_data['epoch'] + start_step = ckp_data.get('step', 0) - # ========== 8. 开始训练 ========== + # ========== 4. 开始训练 ========== for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() @@ -440,5 +454,5 @@ if __name__ == "__main__": else: ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang")) - # ========== 9. 清理分布进程 ========== + # ========== 5. 清理分布进程 ========== if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 40acdfb..600b850 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -22,23 +22,30 @@ warnings.filterwarnings('ignore') def train_epoch(epoch, loader, iters, start_step=0, wandb=None): start_time = time.time() - last_step = start_step for step, (input_ids, labels) in enumerate(loader, start=start_step + 1): + + ########################### 训练前操作 ########################### + # 数据加载 input_ids = input_ids.to(args.device) labels = labels.to(args.device) - last_step = step lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) + + # 学习率调整 for param_group in optimizer.param_groups: param_group['lr'] = lr + ########################### 训练中操作 ########################### + # 模型前向传播 with autocast_ctx: res = model(input_ids, labels=labels) loss = res.loss + res.aux_loss loss = loss / args.accumulation_steps + # 模型反向传播 scaler.scale(loss).backward() - if step % args.accumulation_steps == 0: + # 梯度更新 + if (step + 1) % args.accumulation_steps == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) @@ -47,17 +54,20 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): optimizer.zero_grad(set_to_none=True) - if step % args.log_interval == 0 or step == iters: + ########################### 训练后操作 ########################### + # 日志打印 + if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time current_loss = loss.item() * args.accumulation_steps current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0 current_logits_loss = current_loss - current_aux_loss current_lr = optimizer.param_groups[-1]['lr'] - eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60 + eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min') if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min}) - if (step % args.save_interval == 0 or step == iters) and is_main_process(): + # 模型保存 + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' @@ -71,19 +81,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): del input_ids, labels, res, loss - if last_step > start_step and last_step % args.accumulation_steps != 0: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Pretraining") parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名") - parser.add_argument("--epochs", type=int, default=2, help="训练轮数") + parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)") parser.add_argument("--batch_size", type=int, default=32, help="batch size") parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") @@ -105,55 +108,52 @@ if __name__ == "__main__": parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") args = parser.parse_args() - # ========== 1. 初始化环境和随机种子 ========== + # ========== 1. 创建训练环境 ========== local_rank = init_distributed_mode() - if dist.is_initialized(): args.device = f"cuda:{local_rank}" - setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + if dist.is_initialized(): args.device = f"cuda:{local_rank}" # 分布式训练时,设置设备 + setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) # 设置随机种子 - # ========== 2. 配置目录、模型参数、检查ckp ========== - os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None - - # ========== 3. 设置混合精度 ========== + # ========== 2. 模型相关 ========== + os.makedirs(args.save_dir, exist_ok=True) # 创建模型保存目录 + device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 - autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) - - # ========== 4. 配wandb ========== + autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # 设置混合精度 + + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) + model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) + if args.use_compile == 1: + model = torch.compile(model) + Logger('torch.compile enabled') + train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len) + train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None + scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # 所有模型相关的初始化 + + if dist.is_initialized(): + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} + model = DistributedDataParallel(model, device_ids=[local_rank]) # 分布式训练模型初始化 + + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None # 检查checkpoint + + # ========== 3. checkpoint相关 ========== wandb = None if args.use_wandb and is_main_process(): import swanlab as wandb wandb_id = ckp_data.get('wandb_id') if ckp_data else None resume = 'must' if wandb_id else None wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" - wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) - - # ========== 5. 定义模型、数据、优化器 ========== - model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) - train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len) - train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None - scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) - optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) - - # ========== 6. 从ckp恢复状态 ========== + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # 通过checkpoint进行训练可视化 + start_epoch, start_step = 0, 0 if ckp_data: - model.load_state_dict(ckp_data['model']) + model.module.load_state_dict(ckp_data['model']) optimizer.load_state_dict(ckp_data['optimizer']) scaler.load_state_dict(ckp_data['scaler']) start_epoch = ckp_data['epoch'] - start_step = ckp_data.get('step', 0) + start_step = ckp_data.get('step', 0) # 通过checkpoint进行状态恢复 - # ========== 7. 编译和分布式包装 ========== - if args.use_compile == 1: - model = torch.compile(model) - Logger('torch.compile enabled') - if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - model = DistributedDataParallel(model, device_ids=[local_rank]) - - # ========== 8. 开始训练 ========== + # ========== 4. 训练 ========== for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() @@ -166,5 +166,5 @@ if __name__ == "__main__": else: train_epoch(epoch, loader, len(loader), 0, wandb) - # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + # ========== 5. 撤销训练环境 ========== + if dist.is_initialized(): dist.destroy_process_group() # 撤销分布式训练环境 diff --git a/trainer/train_tokenizer.py b/trainer/train_tokenizer.py index 336a688..04b9da1 100644 --- a/trainer/train_tokenizer.py +++ b/trainer/train_tokenizer.py @@ -9,22 +9,32 @@ TOKENIZER_DIR = '../model_learn_tokenizer/' VOCAB_SIZE = 6400 SPECIAL_TOKENS_NUM = 36 +# 获取文本(train_tokenizer辅助函数) def get_texts(data_path): with open(data_path, 'r', encoding='utf-8', errors='ignore') as f: for i, line in enumerate(f): - if i >= 10000: break # 选10000行测试 + if i >= 10000: break # 选10000行测试(注释掉该行,可以进行所有数据集的文本读取) try: data = json.loads(line) + # 仅利用SFT数据集中conversations字段的content字段,忽略reasoning_content、tools、tool_calls字段 contents = [item.get('content') for item in data.get('conversations', []) if item.get('content')] if contents: + # 注意这里yeild的使用 yield "\n".join(contents) except json.JSONDecodeError: continue def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPECIAL_TOKENS_NUM): - tokenizer = Tokenizer(models.BPE()) - tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + ###########################训练前操作########################### + # 获取文本 + texts = get_texts(data_path) + # 创建tokenizer目录 + os.makedirs(tokenizer_dir, exist_ok=True) + tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json") + tokenizer_config_path = os.path.join(tokenizer_dir, "tokenizer_config.json") + + # special_tokens定义 special_tokens_list = [ "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>", "<|box_end|>", "<|quad_start|>", "<|quad_end|>", @@ -40,21 +50,28 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE num_buffer = special_tokens_num - len(special_tokens_list + additional_tokens_list) buffer_tokens = [f"<|buffer{i}|>" for i in range(1, num_buffer + 1)] # 预留一定数量的token位置 all_special_tokens = special_tokens_list + additional_tokens_list + buffer_tokens + + # 分词器、训练器初始化及预处理 + tokenizer = Tokenizer(models.BPE()) + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) trainer = trainers.BpeTrainer( vocab_size=vocab_size, show_progress=True, initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), special_tokens=all_special_tokens ) - texts = get_texts(data_path) + + ###########################训练中操作########################### tokenizer.train_from_iterator(texts, trainer=trainer) + + ###########################训练后操作########################### + # 后处理 tokenizer.decoder = decoders.ByteLevel() tokenizer.add_special_tokens(special_tokens_list) - os.makedirs(tokenizer_dir, exist_ok=True) - tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + # 修改并保存分词器json文件 + tokenizer.save(tokenizer_json_path) tokenizer.model.save(tokenizer_dir) - tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json") with open(tokenizer_json_path, 'r', encoding='utf-8') as f: tokenizer_data = json.load(f) for token_info in tokenizer_data.get('added_tokens', []): @@ -63,6 +80,7 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE with open(tokenizer_json_path, 'w', encoding='utf-8') as f: json.dump(tokenizer_data, f, ensure_ascii=False, indent=2) + # 创建并保存分词器config文件 added_tokens_decoder = {} for i, token in enumerate(all_special_tokens): idx = tokenizer.token_to_id(token) @@ -101,13 +119,19 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE "tokenizer_class": "PreTrainedTokenizerFast" } - with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f: + with open(tokenizer_config_path, "w", encoding="utf-8") as f: json.dump(config, f, ensure_ascii=False, indent=4) + + # 打印训练完成信息 print("Tokenizer training completed.") def eval_tokenizer(tokenizer_dir): from transformers import AutoTokenizer + ###########################评估前操作########################### + # 加载tokenizer tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) + + # 创建测试消息 messages = [ {"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"}, {"role": "user", "content": '你来自哪里?'}, @@ -119,14 +143,21 @@ def eval_tokenizer(tokenizer_dir): messages, tokenize=False ) + + ###########################评估中操作########################### + # 聊天模版测试 print('-'*100) print(new_prompt) + + # 基础信息测试 print('-'*100) print('tokenizer词表长度:', len(tokenizer)) model_inputs = tokenizer(new_prompt) print('encoder长度:', len(model_inputs['input_ids'])) response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False) print('decoder一致性:', response == new_prompt, "\n") + + # 压缩率测试 print('-'*100) print('压缩率测试(Chars/Tokens):') test_texts = [ @@ -150,6 +181,8 @@ def eval_tokenizer(tokenizer_dir): print(f"样本 {i+1} | 字符数: {char_count:4} | Tokens: {token_count:3} | 压缩率: {compression_ratio:.2f}") print(f"平均压缩率: {total_compression / len(test_texts):.2f}") + + # 流式解码测试 print('-'*100) print('流式解码(字节缓冲)测试:') input_ids = model_inputs['input_ids'] @@ -162,6 +195,10 @@ def eval_tokenizer(tokenizer_dir): raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])] print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}') token_cache = [] + + ###########################评估后操作########################### + # 打印评估完成信息 + print("Tokenizer evaluation completed.") if __name__ == '__main__': train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)