From 773e451b1112ba60ba5446bf9ced7b61daabc2f1 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Mon, 27 Apr 2026 19:16:08 +0800 Subject: [PATCH] [fix] bugs --- trainer/rollout_engine.py | 51 ++++++++++++++++++++++----------------- trainer/train_grpo.py | 22 +++++++++-------- trainer/train_ppo.py | 41 +++++++++++++------------------ 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/trainer/rollout_engine.py b/trainer/rollout_engine.py index 748f118..e433a0b 100644 --- a/trainer/rollout_engine.py +++ b/trainer/rollout_engine.py @@ -43,6 +43,8 @@ class RolloutResult: completion_ids: Tensor per_token_logps: Tensor completions: List[str] + prompt_lens: Tensor + completion_mask: Tensor # ===== Rollout 引擎抽象基类 ===== @@ -71,12 +73,12 @@ class TorchRolloutEngine(RolloutEngine): ctx = self.autocast_ctx if self.autocast_ctx else nullcontext() with torch.no_grad(), ctx: output_ids = model.generate( - input_ids=prompt_ids, - attention_mask=attention_mask, + input_ids=prompt_ids.repeat_interleave(num_generations, dim=0), + attention_mask=attention_mask.repeat_interleave(num_generations, dim=0), max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, - num_return_sequences=num_generations, + num_return_sequences=1, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, ) # [B*num_gen, P+R] @@ -85,7 +87,9 @@ class TorchRolloutEngine(RolloutEngine): full_mask = (output_ids != self.tokenizer.pad_token_id).long() per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask) completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True) - return RolloutResult(output_ids, completion_ids, per_token_logps, completions) + return RolloutResult(output_ids, completion_ids, per_token_logps, completions, + prompt_ids.new_full((output_ids.size(0),), prompt_len), + attention_mask.new_ones(output_ids.size(0), completion_ids.size(1))) def update_policy(self, model: torch.nn.Module): self.policy_model = model @@ -127,7 +131,6 @@ class SGLangRolloutEngine(RolloutEngine): all_output_ids, all_completion_ids, all_logprobs = [], [], [] completions = [] - prompt_len = prompt_ids.size(1) for i, result in enumerate(results): meta = result.get("meta_info", {}) @@ -144,7 +147,7 @@ class SGLangRolloutEngine(RolloutEngine): if len(logprobs) < len(completion_ids): logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs elif len(logprobs) > len(completion_ids): - logprobs = logprobs[-len(completion_ids):] + logprobs = logprobs[-len(completion_ids):] if completion_ids else [] prompt = all_input_ids[i] full_output = prompt + completion_ids all_output_ids.append(full_output) @@ -153,8 +156,8 @@ class SGLangRolloutEngine(RolloutEngine): completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True)) device = prompt_ids.device - max_out_len = max(len(ids) for ids in all_output_ids) - max_comp_len = max(len(ids) for ids in all_completion_ids) + max_comp_len = max(1, max(len(ids) for ids in all_completion_ids)) + max_out_len = max(len(ids) for ids in all_input_ids) + max_comp_len def pad_to_tensor(seqs, max_len, pad_val=0): return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device) @@ -165,25 +168,29 @@ class SGLangRolloutEngine(RolloutEngine): completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id), per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0), completions=completions, + prompt_lens=torch.tensor([len(ids) for ids in all_input_ids], device=device), + completion_mask=torch.tensor([[1] * len(ids) + [0] * (max_comp_len - len(ids)) for ids in all_completion_ids], device=device), ) def update_policy(self, model: torch.nn.Module): ok = True if not dist.is_initialized() or dist.get_rank() == 0: - unwrapped = model.module if isinstance(model, DistributedDataParallel) else model - unwrapped = getattr(unwrapped, '_orig_mod', unwrapped) - abs_path = os.path.abspath(self.shared_ckpt_path) - state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()} - unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False) - self.tokenizer.save_pretrained(abs_path) - resp = self.http.post( - f"{self.base_url}/update_weights_from_disk", - json={"model_path": abs_path}, - timeout=self.timeout - ) - if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}") - ok = resp.status_code == 200 - if dist.is_initialized(): dist.barrier() + try: + unwrapped = model.module if isinstance(model, DistributedDataParallel) else model + unwrapped = getattr(unwrapped, '_orig_mod', unwrapped) + abs_path = os.path.abspath(self.shared_ckpt_path) + state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()} + unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False) + self.tokenizer.save_pretrained(abs_path) + resp = self.http.post(f"{self.base_url}/update_weights_from_disk", json={"model_path": abs_path}, timeout=self.timeout) + if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}") + ok = resp.status_code == 200 + except Exception as e: + print(f"[SGLANG WARNING] update_weights 异常: {e}"); ok = False + if dist.is_initialized(): + ok_t = torch.tensor(int(ok), device=next(model.parameters()).device) + dist.broadcast(ok_t, src=0); dist.barrier(); ok = bool(ok_t.item()) + if not ok: raise RuntimeError("SGLang update_policy failed") return ok def flush_cache(self) -> bool: diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 77ec71a..ea63419 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -22,7 +22,7 @@ from transformers import AutoModel from model.model_minimind import MiniMindConfig, MiniMindForCausalLM from dataset.lm_dataset import RLAIFDataset from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel -from trainer.rollout_engine import create_rollout_engine, compute_per_token_logps +from trainer.rollout_engine import create_rollout_engine warnings.filterwarnings('ignore') @@ -87,17 +87,18 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod completion_ids = rollout_result.completion_ids completions = rollout_result.completions old_per_token_logps = rollout_result.per_token_logps.to(args.device) + prompt_lens = rollout_result.prompt_lens.to(args.device) full_mask = (outputs != tokenizer.pad_token_id).long() + logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0) model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: res = model_unwrapped(outputs, attention_mask=full_mask) aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device) - logits = res.logits[:, :-1, :] - per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):] + per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos) with torch.no_grad(): - ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1), attention_mask=full_mask) + ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos) rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen] if args.debug_mode and is_main_process() and step % args.debug_interval == 0: @@ -120,10 +121,11 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen] advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen] - is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R] - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device) + completion_pad_mask = rollout_result.completion_mask.to(args.device).bool() + is_eos = (completion_ids == tokenizer.eos_token_id) & completion_pad_mask # [B*num_gen, R] + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1) - 1, dtype=torch.long, device=args.device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R] + completion_mask = ((torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)) & completion_pad_mask).int() # [B*num_gen, R] kl_div = ref_per_token_logps - per_token_logps per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R] @@ -136,7 +138,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod per_token_loss1 = ratio * advantages.unsqueeze(1) per_token_loss2 = clipped_ratio * advantages.unsqueeze(1) per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl) - policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)).mean() loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar loss.backward() @@ -152,7 +154,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod current_aux_loss = aux_loss.item() avg_reward_val = rewards.mean().item() avg_len_val = completion_mask.sum(dim=1).float().mean().item() - kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / completion_mask.sum().item() + kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(completion_mask.sum().item(), 1) advantages_mean_val = advantages.mean().item() advantages_std_val = advantages.std().item() current_lr = optimizer.param_groups[0]['lr'] @@ -189,7 +191,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model) del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps - del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask + del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask, completion_pad_mask, prompt_lens, logp_pos if step > start_step and step % args.accumulation_steps != 0: if args.grad_clip > 0: diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index ef5ccc4..228fb68 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -84,7 +84,6 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched prompts = batch["prompt"] # list[str], length B enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len, padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P] - prompt_length = enc.input_ids.shape[1] rollout_result = rollout_engine.rollout( prompt_ids=enc.input_ids, @@ -94,7 +93,10 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched temperature=0.8, ) gen_out = rollout_result.output_ids + completion_ids = rollout_result.completion_ids + prompt_lens = rollout_result.prompt_lens.to(args.device) responses_text = rollout_result.completions + old_resp_logp = rollout_result.per_token_logps.to(args.device) rewards = calculate_rewards(prompts, responses_text, reward_model) # [B] if args.debug_mode and is_main_process() and step % args.debug_interval == 0: @@ -104,7 +106,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}") Logger(prompts[i]) Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}") - Logger(f"[DEBUG] prompt_len={prompt_length}, response_len={len(responses_text[i])}") + Logger(f"[DEBUG] prompt_len={prompt_lens[i].item()}, response_len={len(responses_text[i])}") Logger(f"{'=' * 28} [DEBUG] sample[{i}] RESPONSE_BEGIN {'=' * 28}") Logger(responses_text[i]) Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}") @@ -113,13 +115,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R] labels = gen_out[:, 1:].clone() # [B, P+R-1] - seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1 - resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start B = len(prompts) - resp_labels = labels[:, resp_start:] # [B, R] + resp_labels = completion_ids resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0) - resp_pad_mask = ~resp_labels.eq(tokenizer.pad_token_id) - resp_lengths = resp_pad_mask.sum(dim=1); eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask + logp_pos = prompt_lens.unsqueeze(1) - 1 + resp_idx + resp_pad_mask = rollout_result.completion_mask.to(args.device).bool() + resp_lengths = resp_pad_mask.sum(dim=1); valid_resp = resp_lengths > 0; eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask has_eos = eos_mask.any(dim=1); eos_pos = torch.argmax(eos_mask.int(), dim=1) resp_lengths = torch.where(has_eos, eos_pos + 1, resp_lengths).long().clamp(min=1) resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float() @@ -128,19 +129,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched with torch.no_grad(): # Rollout阶段只需推理获取old_logp和old_values,切断梯度省显存 critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask) - old_resp_values = values_seq[:, resp_start:-1] * resp_value_mask + old_resp_values = values_seq.gather(1, logp_pos) * resp_value_mask - actor_for_rollout = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model - with autocast_ctx: - logits = actor_for_rollout(input_ids=gen_out, attention_mask=full_mask).logits - - old_resp_logp = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)[:, resp_start:] - - ref_logp_all = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) - ref_resp_logp = ref_logp_all[:, resp_start:] + ref_resp_logp = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1).gather(1, logp_pos) token_rewards = torch.zeros_like(old_resp_logp) last_idx = resp_lengths - 1 # [B] - token_rewards[torch.arange(B, device=args.device), last_idx] += rewards # 末尾加外部奖励 + token_rewards[torch.arange(B, device=args.device)[valid_resp], last_idx[valid_resp]] += rewards[valid_resp] # 末尾加外部奖励 gen_len = old_resp_values.size(1); lastgaelam = torch.zeros(B, device=args.device); advs_rev = [] for t in reversed(range(gen_len)): @@ -174,14 +168,13 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched inds = b_inds[i:i + mb_size] mb_values_seq = critic_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds]) - mb_resp_values = mb_values_seq[:, resp_start:-1] + mb_resp_values = mb_values_seq.gather(1, logp_pos[inds]) with autocast_ctx: res = actor_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds]) aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device) - mb_logp_all = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1) - mb_resp_logp = mb_logp_all[:, resp_start:] + mb_resp_logp = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos[inds]) log_ratio = mb_resp_logp - old_resp_logp[inds] approx_kl = (0.5 * (log_ratio ** 2) * resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1) @@ -249,7 +242,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched actor_optimizer.zero_grad() critic_optimizer.zero_grad() - if step % args.save_interval == 0: rollout_engine.update_policy(actor_model) + if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(actor_model) if is_main_process(): critic_loss_val = value_loss_sum / max(log_count, 1) @@ -294,9 +287,9 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched actor_model.train() del actor_state - del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages - del logits, labels, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp - del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values + del enc, gen_out, completion_ids, responses_text, rewards, full_mask, values_seq, advantages + del labels, resp_labels, resp_idx, resp_pad_mask, valid_resp, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_resp_logp + del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values, prompt_lens, logp_pos if __name__ == "__main__":