From d37bfa9d75142e0d2bc273b5a26f648f5077f16e Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 5 Apr 2026 18:11:37 +0000
Subject: [PATCH] [update] harden training and inference reliability

Fix truncated SFT/DPO label masking, make gradient accumulation
resume-safe to prevent unstable optimization, and add explicit eval
backend selection plus regression checks for safer local validation.
---
 REGRESSION_CHECKS.md          | 64 +++++++++++++++++++++++++++++++++++
 dataset/lm_dataset.py         | 43 ++++++++++++++++------
 eval_llm.py                   | 15 +++++++-
 trainer/train_agent.py        |  6 ++--
 trainer/train_distillation.py |  6 ++--
 trainer/train_dpo.py          |  6 ++--
 trainer/train_full_sft.py     |  6 ++--
 trainer/train_grpo.py         |  8 +++--
 trainer/train_lora.py         |  6 ++--
 trainer/train_pretrain.py     | 19 ++++++-----
 10 files changed, 146 insertions(+), 33 deletions(-)
 create mode 100644 REGRESSION_CHECKS.md

diff --git a/REGRESSION_CHECKS.md b/REGRESSION_CHECKS.md
new file mode 100644
index 0000000..0d48dd2
--- /dev/null
+++ b/REGRESSION_CHECKS.md
@@ -0,0 +1,64 @@
+# Regression Checks for Recent Fixes
+
+This file records lightweight checks for the recent high-priority fixes.
+
+## 1) Syntax checks
+
+```bash
+python -m py_compile dataset/lm_dataset.py trainer/train_pretrain.py trainer/train_full_sft.py trainer/train_lora.py trainer/train_distillation.py trainer/train_dpo.py trainer/train_grpo.py trainer/train_agent.py eval_llm.py
+```
+
+## 2) Dataset label/mask boundary check (no eos after truncation)
+
+```bash
+python - <<'PY'
+from dataset.lm_dataset import SFTDataset, DPODataset
+
+s = SFTDataset.__new__(SFTDataset)
+s.bos_id = [11, 12]
+s.eos_id = [13]
+s.max_length = 10
+raw_ids = [99, 11, 12, 21, 22]  # truncated without eos
+labels = s.generate_labels(raw_ids)
+pad_len = s.max_length - len(raw_ids)
+labels = labels + ([-100] * pad_len)
+print("sft_pad_labels_ok", labels[5:] == [-100] * pad_len)
+
+d = DPODataset.__new__(DPODataset)
+d.bos_id = [11, 12]
+d.eos_id = [13]
+d.max_length = 10
+mask = d.generate_loss_mask(raw_ids)
+mask = mask + ([0] * pad_len)
+print("dpo_pad_mask_ok", mask[5:] == [0] * pad_len)
+PY
+```
+
+Expected output:
+
+```text
+sft_pad_labels_ok True
+dpo_pad_mask_ok True
+```
+
+## 3) Eval backend argument check
+
+```bash
+python eval_llm.py -h
+```
+
+Expected: help output includes `--backend {auto,torch,hf}`.
+
+## 4) Manual inference smoke tests
+
+HF backend:
+
+```bash
+python eval_llm.py --backend hf --load_from ./minimind-3 --max_new_tokens 64 --temperature 0.2 --top_p 0.8
+```
+
+Torch backend:
+
+```bash
+python eval_llm.py --backend torch --load_from model --weight full_sft --max_new_tokens 64 --temperature 0.2 --top_p 0.8
+```

diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py
index bbc4197..b1d1638 100644
--- a/dataset/lm_dataset.py
+++ b/dataset/lm_dataset.py
@@ -92,11 +92,19 @@ class SFTDataset(Dataset):
             if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                 start = i + len(self.bos_id)
                 end = start
+                found_eos = False
                 while end < len(input_ids):
                     if input_ids[end:end + len(self.eos_id)] == self.eos_id:
+                        found_eos = True
                         break
                     end += 1
-                for j in range(start, min(end + len(self.eos_id), self.max_length)):
+                # If eos is not found (e.g. after truncation), only supervise real tokens,
+                # never extend to padding positions.
+                if found_eos:
+                    supervise_end = min(end + len(self.eos_id), len(input_ids))
+                else:
+                    supervise_end = min(end, len(input_ids))
+                for j in range(start, supervise_end):
                     labels[j] = input_ids[j]
                 i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
             else:
@@ -109,8 +117,10 @@ class SFTDataset(Dataset):
         prompt = self.create_chat_prompt(conversations)
         prompt = post_processing_chat(prompt)
         input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
-        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
         labels = self.generate_labels(input_ids)
+        pad_len = self.max_length - len(input_ids)
+        input_ids += [self.tokenizer.pad_token_id] * pad_len
+        labels += [-100] * pad_len
         # # === 调试打印 ===
         # print(f"\n--- Sample {index} ---")
         # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
@@ -145,18 +155,21 @@ class DPODataset(Dataset):
             rejected, tokenize=False, add_generation_prompt=False
         )
         rejected_prompt = post_processing_chat(rejected_prompt)
-        chosen_encoding = self.tokenizer(
-            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
-        )
-        rejected_encoding = self.tokenizer(
-            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
-        )
-
-        chosen_input_ids = chosen_encoding['input_ids']
+        chosen_input_ids = self.tokenizer(
+            chosen_prompt, truncation=True, max_length=self.max_length, padding=False
+        )['input_ids']
         chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
+        chosen_pad_len = self.max_length - len(chosen_input_ids)
+        chosen_input_ids += [self.padding] * chosen_pad_len
+        chosen_loss_mask += [0] * chosen_pad_len
-        rejected_input_ids = rejected_encoding['input_ids']
+        rejected_input_ids = self.tokenizer(
+            rejected_prompt, truncation=True, max_length=self.max_length, padding=False
+        )['input_ids']
         rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
+        rejected_pad_len = self.max_length - len(rejected_input_ids)
+        rejected_input_ids += [self.padding] * rejected_pad_len
+        rejected_loss_mask += [0] * rejected_pad_len
 
         x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
         y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
         mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
@@ -180,11 +193,17 @@ class DPODataset(Dataset):
             if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                 start = i + len(self.bos_id)
                 end = start
+                found_eos = False
                 while end < len(input_ids):
                     if input_ids[end:end + len(self.eos_id)] == self.eos_id:
+                        found_eos = True
                         break
                     end += 1
-                for j in range(start, min(end + len(self.eos_id), self.max_length)):
+                if found_eos:
+                    supervise_end = min(end + len(self.eos_id), len(input_ids))
+                else:
+                    supervise_end = min(end, len(input_ids))
+                for j in range(start, supervise_end):
                     loss_mask[j] = 1
                 i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
             else:
diff --git a/eval_llm.py b/eval_llm.py
index 9359ca8..6c38bb2 100755
--- a/eval_llm.py
+++ b/eval_llm.py
@@ -2,6 +2,7 @@ import time
 import argparse
 import random
 import warnings
+import os
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
@@ -11,7 +12,18 @@ warnings.filterwarnings('ignore')
 
 def init_model(args):
     tokenizer = AutoTokenizer.from_pretrained(args.load_from)
-    if 'model' in args.load_from:
+    use_torch_backend = args.backend == 'torch'
+    if args.backend == 'auto':
+        if args.load_from == 'model':
+            use_torch_backend = True
+        elif os.path.isdir(args.load_from):
+            has_hf_config = os.path.exists(os.path.join(args.load_from, "config.json"))
+            has_hf_weight = os.path.exists(os.path.join(args.load_from, "model.safetensors")) or os.path.exists(os.path.join(args.load_from, "pytorch_model.bin"))
+            use_torch_backend = not (has_hf_config or has_hf_weight)
+        else:
+            use_torch_backend = False
+
+    if use_torch_backend:
         model = MiniMindForCausalLM(MiniMindConfig(
             hidden_size=args.hidden_size,
             num_hidden_layers=args.num_hidden_layers,
@@ -32,6 +44,7 @@ def main():
     parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
     parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
+    parser.add_argument('--backend', default='auto', choices=['auto', 'torch', 'hf'], help="模型加载后端(auto/torch/hf)")
     parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
     parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
     parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
diff --git a/trainer/train_agent.py b/trainer/train_agent.py
index 9a84b69..9176b28 100644
--- a/trainer/train_agent.py
+++ b/trainer/train_agent.py
@@ -240,6 +240,7 @@ def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, rewa
 # ================================ 工具与 Reward = End ================================
 def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
     last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
     for step, batch in enumerate(loader, start=start_step + 1):
         messages_batch = batch['messages']
         tools_batch = batch['tools']
@@ -328,8 +329,9 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
                        if valid_rows.any() else per_token_loss.sum() * 0.0)
         loss = (policy_loss + aux_loss) / args.accumulation_steps
         loss.backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
             optimizer.step(); scheduler.step(); optimizer.zero_grad()
             if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
@@ -362,7 +364,7 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
         del per_token_logps, ref_per_token_logps
         del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
 
-    if last_step > start_step and last_step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
         optimizer.step(); scheduler.step(); optimizer.zero_grad()
         if is_main_process() and last_step % args.save_interval == 0: rollout_engine.update_policy(model)
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index 6b1b5b9..be9f668 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -38,6 +38,7 @@ def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction
 def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
     start_time = time.time()
     last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
 
     if teacher_model is not None:
         teacher_model.eval()
@@ -92,8 +93,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
         loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
         scaler.scale(loss).backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
             scaler.step(optimizer)
@@ -134,7 +136,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
 
         del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
 
-    if last_step > start_step and last_step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
         scaler.step(optimizer)
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index 1876e26..c01f34a 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -52,6 +52,7 @@ def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
 def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
     start_time = time.time()
     last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
 
     for step, batch in enumerate(loader, start=start_step + 1):
         last_step = step
@@ -84,8 +85,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
         loss = loss / args.accumulation_steps
         scaler.scale(loss).backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
             scaler.step(optimizer)
@@ -119,7 +121,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
         del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
         del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
 
-    if last_step > start_step and last_step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
         scaler.step(optimizer)
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index f685760..a717a15 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -23,6 +23,7 @@ warnings.filterwarnings('ignore')
 def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
     start_time = time.time()
     last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
     for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
         input_ids = input_ids.to(args.device)
         labels = labels.to(args.device)
@@ -37,8 +38,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         loss = loss / args.accumulation_steps
         scaler.scale(loss).backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
 
@@ -72,7 +74,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
 
         del input_ids, labels, res, loss
 
-    if last_step > start_step and last_step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
         scaler.step(optimizer)
diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 2d514e2..ea72a94 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -68,7 +68,10 @@ def calculate_rewards(prompts, responses, reward_model):
 
 def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model, start_step=0, wandb=None, use_sglang=False):
+    last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
     for step, batch in enumerate(loader, start=start_step + 1):
+        last_step = step
         prompts = batch['prompt']  # list[str], length B
         prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                   padding_side="left", add_special_tokens=False).to(args.device)
@@ -142,8 +145,9 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
         policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         loss = (policy_loss + aux_loss) / args.accumulation_steps  # scalar
         loss.backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             if args.grad_clip > 0:
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
             optimizer.step()
@@ -190,7 +194,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
                 model.train()
             del state_dict
 
-    if step > start_step and step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         if args.grad_clip > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
         optimizer.step()
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index 4dae568..76de7fd 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -24,6 +24,7 @@ warnings.filterwarnings('ignore')
 def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
     start_time = time.time()
     last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
     for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
         input_ids = input_ids.to(args.device)
         labels = labels.to(args.device)
@@ -38,8 +39,9 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
         loss = loss / args.accumulation_steps
         scaler.scale(loss).backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
             scaler.step(optimizer)
@@ -66,7 +68,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
 
         del input_ids, labels, res, loss
 
-    if last_step > start_step and last_step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
         scaler.step(optimizer)
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 40acdfb..4d0b68a 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -23,22 +23,25 @@ warnings.filterwarnings('ignore')
 def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
     start_time = time.time()
     last_step = start_step
+    accum_counter = start_step % args.accumulation_steps
     for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
         input_ids = input_ids.to(args.device)
         labels = labels.to(args.device)
+        attention_mask = (input_ids != tokenizer.pad_token_id).long()
         last_step = step
 
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(input_ids, labels=labels)
+            res = model(input_ids, attention_mask=attention_mask, labels=labels)
             loss = res.loss + res.aux_loss
             loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
+        accum_counter += 1
 
-        if step % args.accumulation_steps == 0:
+        if accum_counter % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
 
@@ -69,9 +72,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
                 model.train()
             del state_dict
 
-        del input_ids, labels, res, loss
+        del input_ids, labels, attention_mask, res, loss
 
-    if last_step > start_step and last_step % args.accumulation_steps != 0:
+    if last_step > start_step and accum_counter % args.accumulation_steps != 0:
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
         scaler.step(optimizer)
@@ -83,7 +86,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Pretraining")
     parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
     parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
-    parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
+    parser.add_argument("--epochs", type=int, default=3, help="训练轮数")
     parser.add_argument("--batch_size", type=int, default=32, help="batch size")
     parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
@@ -94,15 +97,15 @@ if __name__ == "__main__":
     parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
     parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
     parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
-    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+    parser.add_argument('--num_hidden_layers', default=12, type=int, help="隐藏层数量")
    parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
     parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
     parser.add_argument("--data_path", type=str, default="../dataset/pretrain_t2t_mini.jsonl", help="预训练数据路径")
     parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
-    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
+    parser.add_argument('--from_resume', default=1, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
     parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
-    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
+    parser.add_argument("--use_compile", default=1, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
     args = parser.parse_args()
 
     # ========== 1. 初始化环境和随机种子 ==========
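
Reviewer note (not part of the patch): the three sketches below are optional follow-ups in the spirit of `REGRESSION_CHECKS.md`. They are standalone illustrations rather than project code, and the helper names in them (`simulate`, `picks_torch_backend`) are hypothetical.

The accumulation fix seeds `accum_counter` with `start_step % args.accumulation_steps` and increments it once per `backward()`, so optimizer steps land on the same global positions as in a non-resumed run, and the end-of-epoch flush fires only when gradients are actually pending. A minimal simulation of that counter logic:

```python
# Hypothetical standalone sketch of the resume-safe accumulation counter.
def simulate(total_steps, accumulation_steps, start_step=0):
    positions = []                                   # global steps where optimizer.step() fires
    accum_counter = start_step % accumulation_steps  # seeded exactly as in the patch
    for step in range(start_step + 1, total_steps + 1):
        accum_counter += 1                           # one backward() per loader step
        if accum_counter % accumulation_steps == 0:
            positions.append(step)
    # end-of-epoch flush condition from the patch
    flush_leftover = total_steps > start_step and accum_counter % accumulation_steps != 0
    return positions, flush_leftover

print(simulate(10, 4))                # ([4, 8], True): two leftover grads get flushed
print(simulate(10, 4, start_step=6))  # ([8], True): resumed run stays aligned
```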
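
The `auto` backend in `eval_llm.py` falls back to the native torch loader only when `--load_from` is the literal `model` directory or a directory without HF `config.json`/weight files. A standalone sketch of the same decision rule, exercised against temporary directories:

```python
# Hypothetical re-implementation of the auto-detection branch in init_model().
import os
import tempfile

def picks_torch_backend(load_from):
    if load_from == 'model':
        return True
    if os.path.isdir(load_from):
        has_hf_config = os.path.exists(os.path.join(load_from, "config.json"))
        has_hf_weight = (os.path.exists(os.path.join(load_from, "model.safetensors"))
                         or os.path.exists(os.path.join(load_from, "pytorch_model.bin")))
        return not (has_hf_config or has_hf_weight)
    return False

with tempfile.TemporaryDirectory() as d:
    print("empty dir -> torch:", picks_torch_backend(d))     # True: no HF artifacts
    open(os.path.join(d, "config.json"), "w").close()
    print("config.json -> hf:", not picks_torch_backend(d))  # True: HF config found
print("'model' -> torch:", picks_torch_backend('model'))     # True: literal torch dir
```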
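
Finally, check 2 in `REGRESSION_CHECKS.md` covers the truncated no-eos case; this complements it with an eos-terminated sequence built from the same synthetic ids, assuming labels default to `-100` outside supervised spans (as check 2 relies on). Supervision should start after `bos` and stop right after `eos`:

```python
from dataset.lm_dataset import SFTDataset

s = SFTDataset.__new__(SFTDataset)  # same construction trick as check 2
s.bos_id = [11, 12]
s.eos_id = [13]
s.max_length = 10
raw_ids = [99, 11, 12, 21, 22, 13, 98]  # bos, two answer tokens, eos, one trailing token
labels = s.generate_labels(raw_ids)
# Only positions 3..5 (answer tokens plus eos) should be supervised.
print("sft_eos_labels_ok", labels == [-100, -100, -100, 21, 22, 13, -100])
```

Expected output: `sft_eos_labels_ok True`.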