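[fix] rollout bugs: pass attention masks to the policy/ref forward passes, pad SGLang outputs with pad_token_id, and call update_policy on all ranks with a barrier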

This commit is contained in:
jingyaogong
2026-04-27 17:54:09 +08:00
parent d4c6bc5c7e
commit 6361510016
4 changed files with 44 additions and 42 deletions
+22 -17
View File
@@ -82,7 +82,8 @@ class TorchRolloutEngine(RolloutEngine):
) # [B*num_gen, P+R]
prompt_len = prompt_ids.size(1)
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1))
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask)
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
@@ -158,28 +159,32 @@ class SGLangRolloutEngine(RolloutEngine):
# Right-pad every list in `seqs` with `pad_val` up to `max_len` and build a
# single tensor from the padded rows. `device` is not a parameter: it is read
# from the enclosing scope, which lies outside this hunk.
# NOTE(review): a sequence longer than `max_len` is not truncated (a negative
# repeat count adds no padding), so the rows would be ragged -- presumably
# callers pass the batch maximum as `max_len`; confirm against the caller.
def pad_to_tensor(seqs, max_len, pad_val=0):
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
pad_id = self.tokenizer.pad_token_id
return RolloutResult(
output_ids=pad_to_tensor(all_output_ids, max_out_len),
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len),
output_ids=pad_to_tensor(all_output_ids, max_out_len, pad_val=pad_id),
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id),
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
completions=completions,
)
# Push the current policy weights to the rollout server: save the unwrapped
# model and the tokenizer under `self.shared_ckpt_path`, then POST that path to
# `{self.base_url}/update_weights_from_disk`. Returns whether the HTTP status
# was 200.
# NOTE(review): this is a diff view with the +/- markers and indentation
# stripped. The first body below appears to be the removed (pre-commit) lines
# and the second body the added (post-commit) lines -- confirm against the
# repository before treating either half as the live code.
def update_policy(self, model: torch.nn.Module):
# ---- body before this commit (apparently the removed lines) ----
# Non-zero ranks return True immediately; nothing makes them wait for the
# rank that performs the save and the HTTP call.
if dist.is_initialized() and dist.get_rank() != 0: return True
# Strip the DistributedDataParallel wrapper, then an `_orig_mod` wrapper if
# the attribute exists (presumably set by torch.compile -- TODO confirm).
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
abs_path = os.path.abspath(self.shared_ckpt_path)
# Every state-dict tensor is detached, cast with .half() and moved to CPU
# before being handed to save_pretrained.
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
self.tokenizer.save_pretrained(abs_path)
# Ask the server to reload its weights from the directory just written.
resp = self.http.post(
f"{self.base_url}/update_weights_from_disk",
json={"model_path": abs_path},
timeout=self.timeout
)
# A non-200 status is only printed as a warning ("失败" = "failed"); no
# exception is raised here.
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
return resp.status_code == 200
# ---- body after this commit (apparently the added lines) ----
# `ok` keeps its initial True on ranks that skip the upload branch, so the
# return value reflects the HTTP status only on the rank that uploads.
ok = True
# Same save-and-upload sequence as before, now guarded so it runs only when
# torch.distributed is not initialised or on rank 0.
if not dist.is_initialized() or dist.get_rank() == 0:
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
abs_path = os.path.abspath(self.shared_ckpt_path)
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
self.tokenizer.save_pretrained(abs_path)
resp = self.http.post(
f"{self.base_url}/update_weights_from_disk",
json={"model_path": abs_path},
timeout=self.timeout
)
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
ok = resp.status_code == 200
# Every rank calls the barrier when torch.distributed is initialised, so the
# other ranks wait here until rank 0 has finished the upload.
# NOTE(review): if the save or the HTTP call raises on rank 0, rank 0 never
# reaches this barrier while the other ranks are already waiting in it --
# consider a try/finally around the upload.
if dist.is_initialized(): dist.barrier()
return ok
def flush_cache(self) -> bool:
resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
+6 -5
View File
@@ -267,16 +267,17 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device)
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
full_mask = (input_ids != tokenizer.pad_token_id).long()
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(input_ids)
res = model_unwrapped(input_ids, attention_mask=full_mask)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
logits = res.logits[:, :-1, :]
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
with torch.no_grad():
ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1)
ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1, attention_mask=full_mask)
completion_mask = full_response_masks[:, 1:]
is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool()
@@ -332,7 +333,6 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
if step % args.accumulation_steps == 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
if step % args.log_interval == 0 or step == iters:
pl = loss.item() * args.accumulation_steps
@@ -359,13 +359,14 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
model.train()
del state_dict
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
del per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
if last_step > start_step and last_step % args.accumulation_steps != 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
if is_main_process() and last_step % args.save_interval == 0: rollout_engine.update_policy(model)
if __name__ == "__main__":
@@ -470,7 +471,7 @@ if __name__ == "__main__":
rollout_engine.update_policy(model)
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(model)
rollout_engine.update_policy(model)
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
+13 -16
View File
@@ -87,20 +87,17 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
completion_ids = rollout_result.completion_ids
completions = rollout_result.completions
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
full_mask = (outputs != tokenizer.pad_token_id).long()
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
if use_sglang or lm_config.use_moe:
res = model_unwrapped(outputs)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
logits = res.logits[:, :-1, :]
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):]
else:
aux_loss = torch.tensor(0.0, device=args.device)
per_token_logps = rollout_result.per_token_logps
res = model_unwrapped(outputs, attention_mask=full_mask)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
logits = res.logits[:, :-1, :]
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):]
with torch.no_grad():
ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1))
ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1), attention_mask=full_mask)
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
@@ -120,7 +117,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen]
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
@@ -149,7 +146,6 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item() * args.accumulation_steps
@@ -190,16 +186,17 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
model.train()
del state_dict
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
if step > start_step and step % args.accumulation_steps != 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
if __name__ == "__main__":
@@ -312,7 +309,7 @@ if __name__ == "__main__":
rollout_engine.update_policy(model)
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(model)
rollout_engine.update_policy(model)
# ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):
+3 -4
View File
@@ -115,7 +115,6 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
labels = gen_out[:, 1:].clone() # [B, P+R-1]
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start
final_mask = (resp_mask & (~labels.eq(tokenizer.pad_token_id))).float() # [B, P+R-1]
B = len(prompts)
resp_labels = labels[:, resp_start:] # [B, R]
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
@@ -250,7 +249,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
if step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
if is_main_process():
critic_loss_val = value_loss_sum / max(log_count, 1)
@@ -296,7 +295,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
del actor_state
del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages
del logits, labels, final_mask, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
del logits, labels, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values
@@ -423,7 +422,7 @@ if __name__ == "__main__":
if dist.is_initialized():
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(actor_model)
rollout_engine.update_policy(actor_model)
# ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):