[fix] inference bug

This commit is contained in:
jingyaogong
2026-05-03 20:48:48 +08:00
parent 06d882e4ef
commit bdee223036
+1 -1
View File
@@ -81,7 +81,7 @@ class TorchRolloutEngine(RolloutEngine):
num_return_sequences=1,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
) # [B*num_gen, P+R]
).clone() # [B*num_gen, P+R]
prompt_len = prompt_ids.size(1)
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
full_mask = (output_ids != self.tokenizer.pad_token_id).long()