[update] harden training and inference reliability

Fix truncated SFT/DPO label masking, make gradient accumulation resume-safe to prevent unstable optimization, and add explicit eval backend selection plus regression checks for safer local validation.
This commit is contained in:
root 2026-04-05 18:11:37 +00:00
parent 25a7edcd6f
commit d37bfa9d75
10 changed files with 146 additions and 33 deletions

64
REGRESSION_CHECKS.md Normal file
View File

@ -0,0 +1,64 @@
# Regression Checks for Recent Fixes
This file records lightweight checks for the fixes in this change set: SFT/DPO label-mask truncation handling, resume-safe gradient accumulation, and explicit eval backend selection.
## 1) Syntax checks
```bash
python -m py_compile dataset/lm_dataset.py trainer/train_pretrain.py trainer/train_full_sft.py trainer/train_lora.py trainer/train_distillation.py trainer/train_dpo.py trainer/train_grpo.py trainer/train_agent.py eval_llm.py
```
## 2) Dataset label/mask boundary check (no eos after truncation)
```bash
python - <<'PY'
from dataset.lm_dataset import SFTDataset, DPODataset
s = SFTDataset.__new__(SFTDataset)
s.bos_id = [11, 12]
s.eos_id = [13]
s.max_length = 10
raw_ids = [99, 11, 12, 21, 22] # truncated without eos
labels = s.generate_labels(raw_ids)
pad_len = s.max_length - len(raw_ids)
labels = labels + ([-100] * pad_len)
print("sft_pad_labels_ok", labels[5:] == [-100] * pad_len)
d = DPODataset.__new__(DPODataset)
d.bos_id = [11, 12]
d.eos_id = [13]
d.max_length = 10
mask = d.generate_loss_mask(raw_ids)
mask = mask + ([0] * pad_len)
print("dpo_pad_mask_ok", mask[5:] == [0] * pad_len)
PY
```
Expected output:
```text
sft_pad_labels_ok True
dpo_pad_mask_ok True
```
## 3) Eval backend argument check
```bash
python eval_llm.py -h
```
Expected: help output includes `--backend {auto,torch,hf}`.
## 4) Manual inference smoke tests
HF backend:
```bash
python eval_llm.py --backend hf --load_from ./minimind-3 --max_new_tokens 64 --temperature 0.2 --top_p 0.8
```
Torch backend:
```bash
python eval_llm.py --backend torch --load_from model --weight full_sft --max_new_tokens 64 --temperature 0.2 --top_p 0.8
```

View File

@ -92,11 +92,19 @@ class SFTDataset(Dataset):
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
start = i + len(self.bos_id)
end = start
found_eos = False
while end < len(input_ids):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
found_eos = True
break
end += 1
for j in range(start, min(end + len(self.eos_id), self.max_length)):
# If eos is not found (e.g. after truncation), only supervise real tokens,
# never extend to padding positions.
if found_eos:
supervise_end = min(end + len(self.eos_id), len(input_ids))
else:
supervise_end = min(end, len(input_ids))
for j in range(start, supervise_end):
labels[j] = input_ids[j]
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
@ -109,8 +117,10 @@ class SFTDataset(Dataset):
prompt = self.create_chat_prompt(conversations)
prompt = post_processing_chat(prompt)
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
labels = self.generate_labels(input_ids)
pad_len = self.max_length - len(input_ids)
input_ids += [self.tokenizer.pad_token_id] * pad_len
labels += [-100] * pad_len
# # === 调试打印 ===
# print(f"\n--- Sample {index} ---")
# for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
@ -145,18 +155,21 @@ class DPODataset(Dataset):
rejected, tokenize=False, add_generation_prompt=False
)
rejected_prompt = post_processing_chat(rejected_prompt)
chosen_encoding = self.tokenizer(
chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
)
rejected_encoding = self.tokenizer(
rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
)
chosen_input_ids = chosen_encoding['input_ids']
chosen_input_ids = self.tokenizer(
chosen_prompt, truncation=True, max_length=self.max_length, padding=False
)['input_ids']
chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
chosen_pad_len = self.max_length - len(chosen_input_ids)
chosen_input_ids += [self.padding] * chosen_pad_len
chosen_loss_mask += [0] * chosen_pad_len
rejected_input_ids = rejected_encoding['input_ids']
rejected_input_ids = self.tokenizer(
rejected_prompt, truncation=True, max_length=self.max_length, padding=False
)['input_ids']
rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
rejected_pad_len = self.max_length - len(rejected_input_ids)
rejected_input_ids += [self.padding] * rejected_pad_len
rejected_loss_mask += [0] * rejected_pad_len
x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
@ -180,11 +193,17 @@ class DPODataset(Dataset):
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
start = i + len(self.bos_id)
end = start
found_eos = False
while end < len(input_ids):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
found_eos = True
break
end += 1
for j in range(start, min(end + len(self.eos_id), self.max_length)):
if found_eos:
supervise_end = min(end + len(self.eos_id), len(input_ids))
else:
supervise_end = min(end, len(input_ids))
for j in range(start, supervise_end):
loss_mask[j] = 1
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:

View File

@ -2,6 +2,7 @@ import time
import argparse
import random
import warnings
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
@ -11,7 +12,18 @@ warnings.filterwarnings('ignore')
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
if 'model' in args.load_from:
use_torch_backend = args.backend == 'torch'
if args.backend == 'auto':
if args.load_from == 'model':
use_torch_backend = True
elif os.path.isdir(args.load_from):
has_hf_config = os.path.exists(os.path.join(args.load_from, "config.json"))
has_hf_weight = os.path.exists(os.path.join(args.load_from, "model.safetensors")) or os.path.exists(os.path.join(args.load_from, "pytorch_model.bin"))
use_torch_backend = not (has_hf_config or has_hf_weight)
else:
use_torch_backend = False
if use_torch_backend:
model = MiniMindForCausalLM(MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
@ -32,6 +44,7 @@ def init_model(args):
def main():
parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
parser.add_argument('--load_from', default='model', type=str, help="模型加载路径model=原生torch权重其他路径=transformers格式")
parser.add_argument('--backend', default='auto', choices=['auto', 'torch', 'hf'], help="模型加载后端auto/torch/hf")
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo")
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称None表示不使用可选lora_identity, lora_medical")

View File

@ -240,6 +240,7 @@ def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, rewa
# ================================ 工具与 Reward = End ================================
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
last_step = start_step
accum_counter = start_step % args.accumulation_steps
for step, batch in enumerate(loader, start=start_step + 1):
messages_batch = batch['messages']
tools_batch = batch['tools']
@ -328,8 +329,9 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
if valid_rows.any() else per_token_loss.sum() * 0.0)
loss = (policy_loss + aux_loss) / args.accumulation_steps
loss.backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
@ -362,7 +364,7 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
del per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
if last_step > start_step and last_step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
if is_main_process() and last_step % args.save_interval == 0: rollout_engine.update_policy(model)

View File

@ -38,6 +38,7 @@ def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
start_time = time.time()
last_step = start_step
accum_counter = start_step % args.accumulation_steps
if teacher_model is not None:
teacher_model.eval()
@ -92,8 +93,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
scaler.scale(loss).backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
@ -134,7 +136,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)

View File

@ -52,6 +52,7 @@ def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
start_time = time.time()
last_step = start_step
accum_counter = start_step % args.accumulation_steps
for step, batch in enumerate(loader, start=start_step + 1):
last_step = step
@ -84,8 +85,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
@ -119,7 +121,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)

View File

@ -23,6 +23,7 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
accum_counter = start_step % args.accumulation_steps
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
@ -37,8 +38,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -72,7 +74,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)

View File

@ -68,7 +68,10 @@ def calculate_rewards(prompts, responses, reward_model):
def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model, start_step=0, wandb=None, use_sglang=False):
last_step = start_step
accum_counter = start_step % args.accumulation_steps
for step, batch in enumerate(loader, start=start_step + 1):
last_step = step
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device)
@ -142,8 +145,9 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss.backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
@ -190,7 +194,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
model.train()
del state_dict
if step > start_step and step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()

View File

@ -24,6 +24,7 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
accum_counter = start_step % args.accumulation_steps
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
@ -38,8 +39,9 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
@ -66,7 +68,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)

View File

@ -23,22 +23,25 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
accum_counter = start_step % args.accumulation_steps
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
attention_mask = (input_ids != tokenizer.pad_token_id).long()
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(input_ids, labels=labels)
res = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
accum_counter += 1
if step % args.accumulation_steps == 0:
if accum_counter % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -69,9 +72,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
model.train()
del state_dict
del input_ids, labels, res, loss
del input_ids, labels, attention_mask, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
@ -83,7 +86,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
parser.add_argument("--epochs", type=int, default=3, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
@ -94,15 +97,15 @@ if __name__ == "__main__":
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--num_hidden_layers', default=12, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_t2t_mini.jsonl", help="预训练数据路径")
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练为none则从头开始")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_resume', default=1, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=1, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========