[fix] ppo mask

jingyaogong 2025-11-19 23:39:02 +08:00
parent f5374dc87f
commit d7f4f4eab8


@@ -139,7 +139,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
     full_mask = (gen_out != tokenizer.pad_token_id).long()  # [B, P+R]
     values_seq = critic_model(input_ids=gen_out, attention_mask=full_mask)  # [B, P+R]
-    last_indices = full_mask.sum(dim=1) - 1  # [B]
+    last_indices = (full_mask * torch.arange(full_mask.size(1), device=gen_out.device)).argmax(dim=1)  # [B]
     values = values_seq[torch.arange(values_seq.size(0), device=values_seq.device), last_indices]  # [B]
     advantages = rewards - values.detach()  # [B]
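
For context, a minimal standalone sketch (not from the repo) of why the two index computations differ. It assumes the batch can contain left-padded sequences, i.e. pad tokens before the first real token, which is common when prompts are batched for generation; the toy `full_mask` below is illustrative only.

```python
import torch

# Toy attention mask for two length-5 sequences:
# row 0 is right-padded, row 1 is left-padded.
full_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [0, 1, 1, 1, 0]])

# Old computation: counts non-pad tokens, so it only equals the index of the
# last non-pad token when all padding sits on the right.
old_last = full_mask.sum(dim=1) - 1                # tensor([2, 2])  <- wrong for row 1

# New computation: multiply the mask by the position index; the largest
# position whose mask is 1 wins the argmax, which is the true last
# non-pad index regardless of where the padding is.
positions = torch.arange(full_mask.size(1))
new_last = (full_mask * positions).argmax(dim=1)   # tensor([2, 3])

print(old_last, new_last)
```

With right-only padding both expressions agree; the argmax form additionally handles left padding, so the critic value is read from the actual last generated token.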