From 020bd44f3ffed4eaf436b73ae65095de767a186b Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 30 Jan 2026 11:03:35 +0800
Subject: [PATCH] [mod] fix spo algorithm in RLAIF part

---
 dataset/lm_dataset.py |  48 ++++++++
 trainer/train_spo.py  | 250 ++++++++++++++++++++++++++++--------------
 2 files changed, 217 insertions(+), 81 deletions(-)

diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py
index 89f68f3..f4c5736 100644
--- a/dataset/lm_dataset.py
+++ b/dataset/lm_dataset.py
@@ -195,6 +195,53 @@ class DPODataset(Dataset):
         return loss_mask
 
 
+# Add SPODataset
+class SPODataset(Dataset):
+    def __init__(self, jsonl_path, tokenizer, max_length=1024):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.samples = self.load_data(jsonl_path)
+        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
+
+    def __len__(self):
+        return len(self.samples)
+
+    def load_data(self, path):
+        samples = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f, 1):
+                data = json.loads(line.strip())
+                samples.append(data)
+        return samples
+
+    def _create_chat_prompt(self, conversations):
+        """Build a ChatML-formatted conversation prompt."""
+        messages = []
+        answer = ''
+        for i, turn in enumerate(conversations):
+            role = 'user' if i % 2 == 0 else 'assistant'
+            messages.append({"role": role, "content": turn['content']})
+            answer = turn['content']
+        return self.tokenizer.apply_chat_template(
+            messages[:-1],
+            tokenize=False,
+            add_generation_prompt=True  # must be True: generation starts from the assistant tag
+        ), answer
+
+    def __getitem__(self, index):
+        sample = self.samples[index]
+        # Build the conversation prompt
+        prompt, answer = self._create_chat_prompt(sample['conversations'])
+
+        return {
+            'prompt': prompt,
+            'answer': answer,
+            'index': index  # key change: return the sample index for the weighted sampler
+        }
+
+
 class RLAIFDataset(Dataset):
     def __init__(self, jsonl_path, tokenizer, max_length=1024):
         super().__init__()
@@ -240,5 +287,6 @@ class RLAIFDataset(Dataset):
         }
 
 
+
 if __name__ == "__main__":
     pass
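[review note] For reference, a minimal sketch of how SPODataset is consumed downstream; the tokenizer directory and data path below are placeholders, not the repo's actual defaults:

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer
    from dataset.lm_dataset import SPODataset

    tokenizer = AutoTokenizer.from_pretrained("./model")    # hypothetical tokenizer dir
    ds = SPODataset("./dataset/spo_data.jsonl", tokenizer)  # hypothetical data file
    batch = next(iter(DataLoader(ds, batch_size=2)))
    # batch['prompt'] -> chat-templated prompts ending with the assistant generation tag
    # batch['index']  -> LongTensor of dataset indices; this is what lets the sampler
    #                    in train_spo.py update per-sample weights after each step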
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index bac7c14..9a08697 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -18,14 +18,47 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from transformers import AutoModel
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
-from dataset.lm_dataset import RLAIFDataset
+from dataset.lm_dataset import SPODataset
 from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
 
 warnings.filterwarnings('ignore')
 
+
+# --- 1. Custom prioritized sampler (with distributed synchronization) ---
+class WeightedDistributedSampler(DistributedSampler):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, epsilon=1e-5, device='cpu'):
+        super().__init__(dataset, num_replicas, rank, shuffle, seed)
+        # [change]: allow choosing the device where the weights live; CPU by default
+        self.weights = torch.ones(len(dataset), dtype=torch.float32).to(device)
+        self.epsilon = epsilon
+        self.device = device
+
+    def update_weights(self, indices, new_v_estimates):
+        """Update per-sample value estimates."""
+        # if self.weights lives on the GPU, this assignment no longer forces a blocking sync
+        self.weights[indices] = new_v_estimates.to(self.weights.device)
+
+    def sync_weights(self):
+        """Key step: average the weights across ranks at epoch end to prevent sampling drift."""
+        if dist.is_initialized():
+            # NCCL can only reduce GPU tensors, so stage CPU weights on the current device first
+            w = self.weights.cuda() if dist.get_backend() == 'nccl' and not self.weights.is_cuda else self.weights
+            dist.all_reduce(w, op=dist.ReduceOp.SUM)
+            self.weights = (w / dist.get_world_size()).to(self.weights.device)
+
+    def __iter__(self):
+        # priority formula: sqrt(v * (1 - v)) + epsilon
+        priority = torch.sqrt(self.weights * (1 - self.weights) + 1e-8) + self.epsilon
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+
+        # global weighted sampling
+        indices = torch.multinomial(priority, self.total_size, replacement=True, generator=g).tolist()
+        # shard to the current process (rank)
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        return iter(indices)
+
+
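[review note] On the priority formula: sqrt(v * (1 - v)) peaks where the tracked value estimate v is most uncertain (v = 0.5) and vanishes as v approaches 0 or 1, so sampling concentrates on prompts the policy neither reliably fails nor reliably solves. A standalone illustration with toy values:

    import torch
    v = torch.tensor([0.05, 0.30, 0.50, 0.70, 0.95])
    priority = torch.sqrt(v * (1 - v) + 1e-8) + 1e-5
    print(priority)  # ~[0.218, 0.458, 0.500, 0.458, 0.218] -- highest at v = 0.5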
+# --- 2. SPO adaptive value tracker ---
 class AutoAdaptiveValueTracker:
-    """SPO adaptive value tracker"""
     def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
         self.rho_mode = rho_mode
         self.rho_const = rho_const
@@ -42,9 +75,7 @@ class AutoAdaptiveValueTracker:
         return torch.full((batch_size,), baseline, dtype=torch.float32)
 
     def compute_rho(self, cur_mean_logprob):
-        if self.rho_mode == 'constant':
-            return self.rho_const
-        if self.old_mean_logprob is None:
+        if self.rho_mode == 'constant' or self.old_mean_logprob is None:
             return self.rho_const
         kl = abs(self.old_mean_logprob - cur_mean_logprob)
         rho = 2 ** (-kl / self.D_half)
@@ -64,7 +95,7 @@ class AutoAdaptiveValueTracker:
         self.alpha = rho * self.alpha + avg_normalized_reward
         self.beta = rho * self.beta + (1 - avg_normalized_reward)
         return rho
-
+
 
 def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
     """Combine all reward functions into a total reward."""
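[review note] How the tracker's pieces fit together: alpha and beta act like decayed Beta-distribution pseudo-counts, the baseline is their mean alpha / (alpha + beta), and rho is the decay -- with rho_mode='kl' it is 2 ** (-KL / D_half), so the tracker's memory halves each time the policy moves D_half nats, clipped to [clip_lower, clip_upper]. A toy update (standalone numbers, the initial pseudo-counts are assumed, not read from the code):

    alpha, beta = 1.0, 1.0                             # assumed initial pseudo-counts
    rho, avg_normalized_reward = 0.9, 0.8              # one batch, rewards mapped into [0, 1]
    alpha = rho * alpha + avg_normalized_reward        # 1.7
    beta = rho * beta + (1 - avg_normalized_reward)    # 1.1
    baseline = alpha / (alpha + beta)                  # ~0.607, drifting toward recent rewards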
@@ -128,84 +159,96 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
     return rewards
 
 
-def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
-    for step, batch in enumerate(loader, start=start_step + 1):
-        prompts = batch['prompt']  # list[str], length B
-        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
-                                  padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
-        if args.max_seq_len:
-            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
-            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
+# --- 3. Core training loop ---
+def spo_train_epoch(epoch, loader, iters, model, ref_model, reward_model, reward_tokenizer,
+                    value_tracker, sampler, tokenizer, args, autocast_ctx, wandb=None):
+    model.train()
+    for step, batch in enumerate(loader, start=1):
+        prompts = batch['prompt']
+        indices = batch['index']
+
+        # Data preprocessing
+        prompt_inputs = tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding=True,
+            padding_side="left",
+            add_special_tokens=False,
+            return_token_type_ids=False
+        ).to(args.device)
 
-        with torch.no_grad():
-            # a DDP-wrapped model must access generate via .module
+        if args.max_seq_len:
+            prompt_inputs = {k: v[:, -args.max_seq_len:] for k, v in prompt_inputs.items()}
+
+        # 1. Sample generation
+        with torch.no_grad(), autocast_ctx:
             model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
             outputs = model_for_gen.generate(
                 **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
-                num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)  # [B, P+R]
-
-        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B, R]
-
+                pad_token_id=tokenizer.pad_token_id)  # use_cache = False
+        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]
+        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
+
+        # 2. Compute logprobs
         def get_per_token_logps(mdl, input_ids, n_keep):
-            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
             logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
-            per_token_logps = []
-            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
-                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
-                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
-            return torch.stack(per_token_logps)
+            target_ids = input_ids[:, -n_keep:]
+            log_probs = logits.log_softmax(dim=-1)
+            return torch.gather(log_probs, 2, target_ids.unsqueeze(2)).squeeze(2)
+
+        with autocast_ctx:
+            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
+        with torch.no_grad():
+            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
 
-        per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B, R]
-        with torch.no_grad():
-            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B, R]
+        # 3. Reward and advantage computation
+        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)
+        baselines = value_tracker.get_baselines(len(prompts)).to(args.device)
 
-        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)  # list[str], length B
-        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B]
+        # baselines are tracked in [0, 1]; un-normalize to the raw reward scale [-3, 3]
+        # (the inverse of the (r + 3) / 6 mapping used for norm_rewards below)
+        advantages = (rewards - (baselines * 6.0 - 3.0)).clamp(-5.0, 5.0)
 
-        baselines = value_tracker.get_baselines(len(prompts)).to(args.device)  # [B]
+        # 4. Update sampling weights (EMA)
+        norm_rewards = ((rewards.detach() + 3.0) / 6.0).clamp(0, 1)
+        current_v_gpu = sampler.weights[indices].to(args.device)   # pull the stored weights from CPU to GPU
+        updated_v_gpu = 0.7 * current_v_gpu + 0.3 * norm_rewards   # EMA computed entirely on GPU
+        sampler.update_weights(indices, updated_v_gpu.cpu())       # move the result back to CPU storage
 
-        scale = 3.0
-        # Un-normalize baselines to be in the same scale as raw rewards [-3, 3]
-        unnormalized_baselines = baselines * (2 * scale) - scale  # [B]
-        advantages = rewards - unnormalized_baselines  # [B]
+        # 5. Compute the loss
+        is_eos = completion_ids == tokenizer.eos_token_id
+        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
+        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
 
-        # Use the baseline's advantage estimate directly; clip only to prevent gradient explosion.
-        # No in-batch normalization: the baseline already provides a stable cross-batch reference.
-        advantages = advantages.clamp(-5.0, 5.0)
+        kl_div = ref_per_token_logps - per_token_logps
+        per_token_kl = torch.exp(kl_div) - kl_div - 1
+
+        per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl
+        loss = ((per_token_loss * completion_mask).sum(dim=1) / (completion_mask.sum(dim=1) + 1e-8)).mean() / args.accumulation_steps
 
-        is_eos = completion_ids == tokenizer.eos_token_id  # [B, R]
-        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)  # [B]
-        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B, R]
-
-        kl_div = ref_per_token_logps - per_token_logps  # [B, R]
-        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B, R]
-        per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl  # [B, R]
-        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps  # scalar
+        # 6. Backpropagation
         loss.backward()
+
+        rho = value_tracker.update(rewards, per_token_logps.detach(), completion_mask.float())
 
-        response_masks = completion_mask.float()  # [B, R]
-        rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)
-
-        if (step + 1) % args.accumulation_steps == 0:
-            if args.grad_clip > 0:
+        if step % args.accumulation_steps == 0:
+            if args.grad_clip > 0:
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
             optimizer.step()
             scheduler.step()
             optimizer.zero_grad()
             torch.cuda.empty_cache()
 
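[review note] The per-token KL in step 5 is the low-variance "k3" estimator exp(d) - d - 1 with d = ref_logp - logp; it is non-negative and behaves like d**2 / 2 near zero, so it penalizes drift from the reference policy in both directions. A quick standalone check with toy values:

    import torch
    d = torch.tensor([-0.5, 0.0, 0.5])   # ref_per_token_logps - per_token_logps
    k3 = torch.exp(d) - d - 1
    print(k3)                            # tensor([0.1065, 0.0000, 0.1487]), always >= 0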
+        # 7. Logging
         if step % args.log_interval == 0 or step == iters:
-            policy_loss_val = loss.item()
+            policy_loss_val = loss.item() * args.accumulation_steps  # undo the accumulation scaling for display
             avg_reward_val = rewards.mean().item()
             avg_len_val = completion_mask.sum(dim=1).float().mean().item()
             kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
             avg_baseline_val = baselines.mean().item()
             current_lr = optimizer.param_groups[0]['lr']
-            Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
-                   f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
-                   f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
-                   f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
+            Logger(f'Step: {step}/{iters}, Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
+                   f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}')
 
             if wandb and is_main_process():
                 wandb.log({
@@ -215,26 +258,30 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
                     "rho": float(rho),
                     "baseline": avg_baseline_val,
                     "advantages_mean": advantages.mean().item(),
-                    "learning_rate": current_lr
+                    "learning_rate": current_lr,
                 })
 
-        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
+        # 8. Checkpoint saving
+        if (step % args.save_interval == 0 or step == iters) and is_main_process():
+            # ### <--- change 5: make sure lm_config is in scope (usually taken from args or the model)
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
-            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
-                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
+            lm_checkpoint(model.module.config if hasattr(model, 'module') else model.config,
+                          weight=args.save_weight, model=model, optimizer=optimizer,
+                          epoch=epoch, step=step, wandb=wandb, save_dir=args.save_dir, scheduler=scheduler)
             model.train()
             del state_dict
 
+        # Free memory
         del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
-        del completions, rewards, advantages, completion_mask, baselines, response_masks
+        del completions, rewards, advantages, completion_mask, baselines
         torch.cuda.empty_cache()
         gc.collect()
 
-
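[review note] The save block in step 8 writes a half-precision CPU state_dict. A sketch of loading it back for inference, assuming a MiniMindConfig matching the one used in training (the hidden size and file name below are illustrative, not the repo's defaults):

    import torch
    from model.model_minimind import MiniMindConfig, MiniMindForCausalLM

    cfg = MiniMindConfig(hidden_size=512)                         # assumption: match the trained config
    model = MiniMindForCausalLM(cfg)
    state = torch.load("../out/spo_512.pth", map_location="cpu")  # hypothetical file name
    model.load_state_dict({k: v.float() for k, v in state.items()})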
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
     parser.add_argument("--save_dir", type=str, default="../out", help="model save directory")
@@ -286,33 +333,54 @@ if __name__ == "__main__":
         wandb_id = ckp_data.get('wandb_id') if ckp_data else None
         resume = 'must' if wandb_id else None
         wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
-        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+        # "local" is not a valid wandb mode; "offline" keeps runs on disk without a server
+        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume, mode="offline")
 
     # ========== 5. Initialize models (Policy, Ref, Reward), Value Tracker, and data ==========
     base_weight = "reason" if args.reasoning == 1 else "full_sft"
     # Policy model
     model, tokenizer = init_model(lm_config, base_weight, device=args.device)
     # Reference model
     ref_model, _ = init_model(lm_config, base_weight, device=args.device)
     ref_model = ref_model.eval().requires_grad_(False)
+
     # Reward model
     reward_model = AutoModel.from_pretrained(
         args.reward_model_path,
         torch_dtype=torch.float16,
         trust_remote_code=True
     )
     reward_model = reward_model.to(args.device).eval().requires_grad_(False)
     reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
+
     # Value Tracker
     value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
 
-    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
-    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+    # --- Key change annotations ---
+    train_ds = SPODataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
+
+    # [note 1]: num_replicas and rank must be passed explicitly so distributed sampling is sharded correctly
+    train_sampler = WeightedDistributedSampler(
+        train_ds,
+        num_replicas=dist.get_world_size() if dist.is_initialized() else 1,
+        rank=dist.get_rank() if dist.is_initialized() else 0  # global rank; local_rank would break on multi-node
+    )
+
+    # [note 2]: the DataLoader must be bound to this train_sampler
+    loader = DataLoader(
+        train_ds,
+        batch_size=args.batch_size,
+        sampler=train_sampler,
+        num_workers=args.num_workers,
+        pin_memory=True
+    )
+
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
 
-    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
-    iters = len(loader_for_count)
-    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
-    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
+    # use the loader's length directly; no need for a separate loader_for_count
+    iters = len(loader)
+    # make sure total_optimizer_steps accounts for gradient accumulation
+    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
+    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
+
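[review note] Step bookkeeping example: T_max must count optimizer steps, not micro-batches, because scheduler.step() runs only inside the accumulation branch of the training loop above. With toy numbers:

    iters, accumulation_steps, epochs = 1000, 8, 2
    total_optimizer_steps = (iters // accumulation_steps) * epochs  # (1000 // 8) * 2 = 250
    # CosineAnnealingLR(T_max=250) then completes exactly one annealing schedule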
     # ========== 6. Restore state from checkpoint ==========
     start_epoch, start_step = 0, 0
     if ckp_data:
@@ -327,16 +395,36 @@ if __name__ == "__main__":
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[local_rank])
 
     # ========== 8. Start training ==========
     for epoch in range(start_epoch, args.epochs):
-        train_sampler and train_sampler.set_epoch(epoch)
-        if epoch == start_epoch and start_step > 0:  # first epoch with an existing checkpoint
-            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+        # set_epoch must be called so the priority-based sampling reshuffles every epoch
+        train_sampler.set_epoch(epoch)
+
+        if epoch == start_epoch and start_step > 0:
+            # resume logic
+            batch_sampler = SkipBatchSampler(train_sampler, args.batch_size, start_step + 1)
             loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
-            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, starting from step {start_step + 1}')
-            spo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
-        else:  # start from scratch by default
+
+            # [change 1]: fill in the arguments the old call was missing (model, tokenizer, args, autocast_ctx, ...)
+            # note: iters here is the global step count, len(loader) + start_step + 1
+            spo_train_epoch(
+                epoch, loader, len(loader) + start_step + 1,
+                model, ref_model, reward_model, reward_tokenizer,
+                value_tracker, train_sampler, tokenizer,
+                args, autocast_ctx, wandb
+            )
+            train_sampler.sync_weights()
+
+        else:
+            # regular training
             loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
-                                drop_last=False, shuffle=(train_sampler is None),
-                                num_workers=args.num_workers, sampler=train_sampler)
-            spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
+                                drop_last=False, num_workers=args.num_workers, sampler=train_sampler)
+
+            spo_train_epoch(
+                epoch, loader, len(loader),
+                model, ref_model, reward_model, reward_tokenizer,
+                value_tracker, train_sampler, tokenizer,
+                args, autocast_ctx, wandb
+            )
+            # sync_weights must run at the end of every epoch to synchronize sampling weights across ranks
+            train_sampler.sync_weights()
\ No newline at end of file
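[review note] One last sanity check worth keeping around: the EOS masking in step 5 keeps tokens up to and including the first EOS and leaves sequences with no EOS fully unmasked (eos_idx defaults to the sequence length). Standalone toy check with eos_token_id = 2:

    import torch
    completion_ids = torch.tensor([[5, 2, 7, 7],    # EOS at position 1
                                   [5, 6, 7, 8]])   # no EOS generated
    is_eos = completion_ids == 2
    eos_idx = torch.full((2,), 4, dtype=torch.long)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    mask = (torch.arange(4).expand(2, -1) <= eos_idx.unsqueeze(1)).int()
    print(mask)  # tensor([[1, 1, 0, 0], [1, 1, 1, 1]])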