mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-03 12:52:34 +00:00
[fix] inference bug
This commit is contained in:
@@ -81,7 +81,7 @@ class TorchRolloutEngine(RolloutEngine):
|
||||
num_return_sequences=1,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
) # [B*num_gen, P+R]
|
||||
).clone() # [B*num_gen, P+R]
|
||||
prompt_len = prompt_ids.size(1)
|
||||
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
|
||||
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
|
||||
|
||||
Reference in New Issue
Block a user