mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[fix] ppo mask
This commit is contained in:
parent
f5374dc87f
commit
d7f4f4eab8
@@ -139,7 +139,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
|
||||
|
||||
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
|
||||
values_seq = critic_model(input_ids=gen_out, attention_mask=full_mask) # [B, P+R]
|
||||
last_indices = full_mask.sum(dim=1) - 1 # [B]
|
||||
last_indices = (full_mask * torch.arange(full_mask.size(1), device=gen_out.device)).argmax(dim=1)
|
||||
values = values_seq[torch.arange(values_seq.size(0), device=values_seq.device), last_indices] # [B]
|
||||
advantages = rewards - values.detach() # [B]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user