[fix] bugs

jingyaogong 2026-04-27 19:16:08 +08:00
parent 6361510016
commit 773e451b11
3 changed files with 58 additions and 56 deletions

trainer/rollout_engine.py

@@ -43,6 +43,8 @@ class RolloutResult:
completion_ids: Tensor
per_token_logps: Tensor
completions: List[str]
prompt_lens: Tensor
completion_mask: Tensor
# ===== Rollout engine abstract base class =====
@@ -71,12 +73,12 @@ class TorchRolloutEngine(RolloutEngine):
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
with torch.no_grad(), ctx:
output_ids = model.generate(
input_ids=prompt_ids,
attention_mask=attention_mask,
input_ids=prompt_ids.repeat_interleave(num_generations, dim=0),
attention_mask=attention_mask.repeat_interleave(num_generations, dim=0),
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
num_return_sequences=num_generations,
num_return_sequences=1,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
) # [B*num_gen, P+R]
@@ -85,7 +87,9 @@ class TorchRolloutEngine(RolloutEngine):
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask)
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
return RolloutResult(output_ids, completion_ids, per_token_logps, completions,
prompt_ids.new_full((output_ids.size(0),), prompt_len),
attention_mask.new_ones(output_ids.size(0), completion_ids.size(1)))
def update_policy(self, model: torch.nn.Module):
self.policy_model = model
@@ -127,7 +131,6 @@ class SGLangRolloutEngine(RolloutEngine):
all_output_ids, all_completion_ids, all_logprobs = [], [], []
completions = []
prompt_len = prompt_ids.size(1)
for i, result in enumerate(results):
meta = result.get("meta_info", {})
@@ -144,7 +147,7 @@ class SGLangRolloutEngine(RolloutEngine):
if len(logprobs) < len(completion_ids):
logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs
elif len(logprobs) > len(completion_ids):
logprobs = logprobs[-len(completion_ids):]
logprobs = logprobs[-len(completion_ids):] if completion_ids else []
prompt = all_input_ids[i]
full_output = prompt + completion_ids
all_output_ids.append(full_output)
@@ -153,8 +156,8 @@ class SGLangRolloutEngine(RolloutEngine):
completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True))
device = prompt_ids.device
max_out_len = max(len(ids) for ids in all_output_ids)
max_comp_len = max(len(ids) for ids in all_completion_ids)
max_comp_len = max(1, max(len(ids) for ids in all_completion_ids))
max_out_len = max(len(ids) for ids in all_input_ids) + max_comp_len
def pad_to_tensor(seqs, max_len, pad_val=0):
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
@@ -165,25 +168,29 @@ class SGLangRolloutEngine(RolloutEngine):
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id),
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
completions=completions,
prompt_lens=torch.tensor([len(ids) for ids in all_input_ids], device=device),
completion_mask=torch.tensor([[1] * len(ids) + [0] * (max_comp_len - len(ids)) for ids in all_completion_ids], device=device),
)
def update_policy(self, model: torch.nn.Module):
ok = True
if not dist.is_initialized() or dist.get_rank() == 0:
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
abs_path = os.path.abspath(self.shared_ckpt_path)
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
self.tokenizer.save_pretrained(abs_path)
resp = self.http.post(
f"{self.base_url}/update_weights_from_disk",
json={"model_path": abs_path},
timeout=self.timeout
)
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights failed: {resp.status_code}, {resp.text}")
ok = resp.status_code == 200
if dist.is_initialized(): dist.barrier()
try:
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
abs_path = os.path.abspath(self.shared_ckpt_path)
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
self.tokenizer.save_pretrained(abs_path)
resp = self.http.post(f"{self.base_url}/update_weights_from_disk", json={"model_path": abs_path}, timeout=self.timeout)
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights failed: {resp.status_code}, {resp.text}")
ok = resp.status_code == 200
except Exception as e:
print(f"[SGLANG WARNING] update_weights 异常: {e}"); ok = False
if dist.is_initialized():
ok_t = torch.tensor(int(ok), device=next(model.parameters()).device)
dist.broadcast(ok_t, src=0); dist.barrier(); ok = bool(ok_t.item())
if not ok: raise RuntimeError("SGLang update_policy failed")
return ok
def flush_cache(self) -> bool:
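
For reference, a standalone toy example of the padding scheme the SGLang engine uses above (all values here are invented for illustration, not part of this commit): variable-length completions are right-padded to max_comp_len, which is clamped to at least 1 so a batch containing empty completions still builds a valid tensor, and completion_mask marks the real (non-pad) tokens.

import torch

all_completion_ids = [[5, 6, 7], [8, 9], []]  # one deliberately empty completion
pad_id = 0
max_comp_len = max(1, max(len(ids) for ids in all_completion_ids))  # never 0

completion_ids = torch.tensor(
    [ids + [pad_id] * (max_comp_len - len(ids)) for ids in all_completion_ids])
completion_mask = torch.tensor(
    [[1] * len(ids) + [0] * (max_comp_len - len(ids)) for ids in all_completion_ids])

print(completion_ids)   # tensor([[5, 6, 7], [8, 9, 0], [0, 0, 0]])
print(completion_mask)  # tensor([[1, 1, 1], [1, 1, 0], [0, 0, 0]])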

GRPO trainer (grpo_train_epoch)

@@ -22,7 +22,7 @@ from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
from trainer.rollout_engine import create_rollout_engine, compute_per_token_logps
from trainer.rollout_engine import create_rollout_engine
warnings.filterwarnings('ignore')
@@ -87,17 +87,18 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
completion_ids = rollout_result.completion_ids
completions = rollout_result.completions
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
prompt_lens = rollout_result.prompt_lens.to(args.device)
full_mask = (outputs != tokenizer.pad_token_id).long()
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(outputs, attention_mask=full_mask)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
logits = res.logits[:, :-1, :]
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):]
per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
with torch.no_grad():
ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1), attention_mask=full_mask)
ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
@@ -120,10 +121,11 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen]
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
completion_pad_mask = rollout_result.completion_mask.to(args.device).bool()
is_eos = (completion_ids == tokenizer.eos_token_id) & completion_pad_mask # [B*num_gen, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1) - 1, dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R]
completion_mask = ((torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)) & completion_pad_mask).int() # [B*num_gen, R]
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
@@ -136,7 +138,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
per_token_loss1 = ratio * advantages.unsqueeze(1)
per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)).mean()
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss.backward()
@@ -152,7 +154,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
current_aux_loss = aux_loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / completion_mask.sum().item()
kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(completion_mask.sum().item(), 1)
advantages_mean_val = advantages.mean().item()
advantages_std_val = advantages.std().item()
current_lr = optimizer.param_groups[0]['lr']
@@ -189,7 +191,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask, completion_pad_mask, prompt_lens, logp_pos
if step > start_step and step % args.accumulation_steps != 0:
if args.grad_clip > 0:
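
For context, a toy sketch of the logp_pos indexing introduced above (shapes, values, and the random tensors are invented): a causal LM's logit at position t predicts token t+1, so after dropping the last logit and shifting the labels by one, the log-prob of completion token j for a sample whose prompt has length p sits at index p - 1 + j of the shifted sequence.

import torch
import torch.nn.functional as F

B, T, V, R = 2, 6, 10, 3                      # batch, full seq len, vocab, completion len
logits = torch.randn(B, T, V)                 # stand-in for model(outputs).logits
outputs = torch.randint(0, V, (B, T))         # prompt + completion token ids
prompt_lens = torch.tensor([3, 2])            # per-sample prompt lengths

token_logps = F.log_softmax(logits[:, :-1, :], dim=-1).gather(
    2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)                        # [B, T-1]
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(R).unsqueeze(0)  # [B, R]
per_token_logps = token_logps.gather(1, logp_pos)                       # [B, R]
print(per_token_logps.shape)                  # torch.Size([2, 3])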

PPO trainer (ppo_train_epoch)

@@ -84,7 +84,6 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
prompts = batch["prompt"] # list[str], length B
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
prompt_length = enc.input_ids.shape[1]
rollout_result = rollout_engine.rollout(
prompt_ids=enc.input_ids,
@@ -94,7 +93,10 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
temperature=0.8,
)
gen_out = rollout_result.output_ids
completion_ids = rollout_result.completion_ids
prompt_lens = rollout_result.prompt_lens.to(args.device)
responses_text = rollout_result.completions
old_resp_logp = rollout_result.per_token_logps.to(args.device)
rewards = calculate_rewards(prompts, responses_text, reward_model) # [B]
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
@@ -104,7 +106,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
Logger(prompts[i])
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
Logger(f"[DEBUG] prompt_len={prompt_length}, response_len={len(responses_text[i])}")
Logger(f"[DEBUG] prompt_len={prompt_lens[i].item()}, response_len={len(responses_text[i])}")
Logger(f"{'=' * 28} [DEBUG] sample[{i}] RESPONSE_BEGIN {'=' * 28}")
Logger(responses_text[i])
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
@@ -113,13 +115,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
labels = gen_out[:, 1:].clone() # [B, P+R-1]
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start
B = len(prompts)
resp_labels = labels[:, resp_start:] # [B, R]
resp_labels = completion_ids
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
resp_pad_mask = ~resp_labels.eq(tokenizer.pad_token_id)
resp_lengths = resp_pad_mask.sum(dim=1); eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
logp_pos = prompt_lens.unsqueeze(1) - 1 + resp_idx
resp_pad_mask = rollout_result.completion_mask.to(args.device).bool()
resp_lengths = resp_pad_mask.sum(dim=1); valid_resp = resp_lengths > 0; eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
has_eos = eos_mask.any(dim=1); eos_pos = torch.argmax(eos_mask.int(), dim=1)
resp_lengths = torch.where(has_eos, eos_pos + 1, resp_lengths).long().clamp(min=1)
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
@@ -128,19 +129,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
with torch.no_grad(): # the rollout phase only needs inference to get old_logp and old_values; cutting gradients saves memory
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
old_resp_values = values_seq[:, resp_start:-1] * resp_value_mask
old_resp_values = values_seq.gather(1, logp_pos) * resp_value_mask
actor_for_rollout = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
with autocast_ctx:
logits = actor_for_rollout(input_ids=gen_out, attention_mask=full_mask).logits
old_resp_logp = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)[:, resp_start:]
ref_logp_all = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
ref_resp_logp = ref_logp_all[:, resp_start:]
ref_resp_logp = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
token_rewards = torch.zeros_like(old_resp_logp)
last_idx = resp_lengths - 1 # [B]
token_rewards[torch.arange(B, device=args.device), last_idx] += rewards # add the external reward at the final token
token_rewards[torch.arange(B, device=args.device)[valid_resp], last_idx[valid_resp]] += rewards[valid_resp] # add the external reward at the final token
gen_len = old_resp_values.size(1); lastgaelam = torch.zeros(B, device=args.device); advs_rev = []
for t in reversed(range(gen_len)):
@@ -174,14 +168,13 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
inds = b_inds[i:i + mb_size]
mb_values_seq = critic_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
mb_resp_values = mb_values_seq[:, resp_start:-1]
mb_resp_values = mb_values_seq.gather(1, logp_pos[inds])
with autocast_ctx:
res = actor_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
mb_logp_all = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1)
mb_resp_logp = mb_logp_all[:, resp_start:]
mb_resp_logp = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos[inds])
log_ratio = mb_resp_logp - old_resp_logp[inds]
approx_kl = (0.5 * (log_ratio ** 2) * resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
@@ -249,7 +242,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
if step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(actor_model)
if is_main_process():
critic_loss_val = value_loss_sum / max(log_count, 1)
@@ -294,9 +287,9 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
actor_model.train()
del actor_state
del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages
del logits, labels, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values
del enc, gen_out, completion_ids, responses_text, rewards, full_mask, values_seq, advantages
del labels, resp_labels, resp_idx, resp_pad_mask, valid_resp, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_resp_logp
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values, prompt_lens, logp_pos
if __name__ == "__main__":
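
As a closing reference for the PPO changes, a generic sketch of the backward pass performed by the lastgaelam loop left unchanged above (gamma, lam, and every tensor here are invented stand-ins, not the repo's exact code): per-token advantages are accumulated from the last completion token backwards, and returns are advantages plus the old value estimates.

import torch

B, gen_len = 2, 4
gamma, lam = 1.0, 0.95
token_rewards = torch.randn(B, gen_len)       # per-token rewards, external reward already added at the end
values = torch.randn(B, gen_len)              # stand-in for old_resp_values

lastgaelam = torch.zeros(B)
advs_rev = []
for t in reversed(range(gen_len)):
    next_value = values[:, t + 1] if t < gen_len - 1 else torch.zeros(B)
    delta = token_rewards[:, t] + gamma * next_value - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advs_rev.append(lastgaelam)
advantages = torch.stack(advs_rev[::-1], dim=1)  # [B, gen_len]
returns = advantages + values
print(advantages.shape, returns.shape)           # torch.Size([2, 4]) torch.Size([2, 4])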