diff --git a/trainer/rollout_engine.py b/trainer/rollout_engine.py index 7ef2eae..748f118 100644 --- a/trainer/rollout_engine.py +++ b/trainer/rollout_engine.py @@ -82,7 +82,8 @@ class TorchRolloutEngine(RolloutEngine): ) # [B*num_gen, P+R] prompt_len = prompt_ids.size(1) completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R] - per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1)) + full_mask = (output_ids != self.tokenizer.pad_token_id).long() + per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask) completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True) return RolloutResult(output_ids, completion_ids, per_token_logps, completions) @@ -158,28 +159,32 @@ class SGLangRolloutEngine(RolloutEngine): def pad_to_tensor(seqs, max_len, pad_val=0): return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device) + pad_id = self.tokenizer.pad_token_id return RolloutResult( - output_ids=pad_to_tensor(all_output_ids, max_out_len), - completion_ids=pad_to_tensor(all_completion_ids, max_comp_len), + output_ids=pad_to_tensor(all_output_ids, max_out_len, pad_val=pad_id), + completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id), per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0), completions=completions, ) def update_policy(self, model: torch.nn.Module): - if dist.is_initialized() and dist.get_rank() != 0: return True - unwrapped = model.module if isinstance(model, DistributedDataParallel) else model - unwrapped = getattr(unwrapped, '_orig_mod', unwrapped) - abs_path = os.path.abspath(self.shared_ckpt_path) - state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()} - unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False) - self.tokenizer.save_pretrained(abs_path) - resp = self.http.post( - f"{self.base_url}/update_weights_from_disk", - json={"model_path": abs_path}, - timeout=self.timeout - ) - if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}") - return resp.status_code == 200 + ok = True + if not dist.is_initialized() or dist.get_rank() == 0: + unwrapped = model.module if isinstance(model, DistributedDataParallel) else model + unwrapped = getattr(unwrapped, '_orig_mod', unwrapped) + abs_path = os.path.abspath(self.shared_ckpt_path) + state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()} + unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False) + self.tokenizer.save_pretrained(abs_path) + resp = self.http.post( + f"{self.base_url}/update_weights_from_disk", + json={"model_path": abs_path}, + timeout=self.timeout + ) + if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}") + ok = resp.status_code == 200 + if dist.is_initialized(): dist.barrier() + return ok def flush_cache(self) -> bool: resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30) diff --git a/trainer/train_agent.py b/trainer/train_agent.py index ff21d62..2138fb2 100644 --- a/trainer/train_agent.py +++ b/trainer/train_agent.py @@ -267,16 +267,17 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device) full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32) old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32) + full_mask = (input_ids != tokenizer.pad_token_id).long() model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: - res = model_unwrapped(input_ids) + res = model_unwrapped(input_ids, attention_mask=full_mask) aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device) logits = res.logits[:, :-1, :] per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1) with torch.no_grad(): - ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1) + ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1, attention_mask=full_mask) completion_mask = full_response_masks[:, 1:] is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool() @@ -332,7 +333,6 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model if step % args.accumulation_steps == 0: if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step(); scheduler.step(); optimizer.zero_grad() - if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model) if step % args.log_interval == 0 or step == iters: pl = loss.item() * args.accumulation_steps @@ -359,13 +359,14 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model model.train() del state_dict + if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model) + del per_token_logps, ref_per_token_logps del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask if last_step > start_step and last_step % args.accumulation_steps != 0: if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step(); scheduler.step(); optimizer.zero_grad() - if is_main_process() and last_step % args.save_interval == 0: rollout_engine.update_policy(model) if __name__ == "__main__": @@ -470,7 +471,7 @@ if __name__ == "__main__": rollout_engine.update_policy(model) if dist.is_initialized(): model = DistributedDataParallel(model, device_ids=[local_rank]) - if is_main_process(): rollout_engine.update_policy(model) + rollout_engine.update_policy(model) for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 196c201..77ec71a 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -87,20 +87,17 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod completion_ids = rollout_result.completion_ids completions = rollout_result.completions old_per_token_logps = rollout_result.per_token_logps.to(args.device) + full_mask = (outputs != tokenizer.pad_token_id).long() model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: - if use_sglang or lm_config.use_moe: - res = model_unwrapped(outputs) - aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device) - logits = res.logits[:, :-1, :] - per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):] - else: - aux_loss = torch.tensor(0.0, device=args.device) - per_token_logps = rollout_result.per_token_logps + res = model_unwrapped(outputs, attention_mask=full_mask) + aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device) + logits = res.logits[:, :-1, :] + per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):] with torch.no_grad(): - ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1)) + ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1), attention_mask=full_mask) rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen] if args.debug_mode and is_main_process() and step % args.debug_interval == 0: @@ -120,7 +117,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen] mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen] - std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen] + std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen] advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen] is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R] @@ -149,7 +146,6 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod optimizer.step() scheduler.step() optimizer.zero_grad() - if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model) if step % args.log_interval == 0 or step == iters: policy_loss_val = loss.item() * args.accumulation_steps @@ -190,16 +186,17 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod model.train() del state_dict + if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model) + + del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps + del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask + if step > start_step and step % args.accumulation_steps != 0: if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() scheduler.step() optimizer.zero_grad() - if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model) - - del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps - del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask if __name__ == "__main__": @@ -312,7 +309,7 @@ if __name__ == "__main__": rollout_engine.update_policy(model) if dist.is_initialized(): model = DistributedDataParallel(model, device_ids=[local_rank]) - if is_main_process(): rollout_engine.update_policy(model) + rollout_engine.update_policy(model) # ========== 8. 开始训练 ========== for epoch in range(start_epoch, args.epochs): diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 1faf74b..ef5ccc4 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -115,7 +115,6 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched labels = gen_out[:, 1:].clone() # [B, P+R-1] seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1 resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start - final_mask = (resp_mask & (~labels.eq(tokenizer.pad_token_id))).float() # [B, P+R-1] B = len(prompts) resp_labels = labels[:, resp_start:] # [B, R] resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0) @@ -250,7 +249,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched actor_optimizer.zero_grad() critic_optimizer.zero_grad() - if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model) + if step % args.save_interval == 0: rollout_engine.update_policy(actor_model) if is_main_process(): critic_loss_val = value_loss_sum / max(log_count, 1) @@ -296,7 +295,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched del actor_state del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages - del logits, labels, final_mask, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp + del logits, labels, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values @@ -423,7 +422,7 @@ if __name__ == "__main__": if dist.is_initialized(): actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank]) critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank]) - if is_main_process(): rollout_engine.update_policy(actor_model) + rollout_engine.update_policy(actor_model) # ========== 8. 开始训练 ========== for epoch in range(start_epoch, args.epochs):