diff --git a/trainer/rollout_engine.py b/trainer/rollout_engine.py index e433a0b..3a1aab0 100644 --- a/trainer/rollout_engine.py +++ b/trainer/rollout_engine.py @@ -81,7 +81,7 @@ class TorchRolloutEngine(RolloutEngine): num_return_sequences=1, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, - ) # [B*num_gen, P+R] + ).clone() # [B*num_gen, P+R] prompt_len = prompt_ids.size(1) completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R] full_mask = (output_ids != self.tokenizer.pad_token_id).long()