[None][perf] Fix TPOT when min_tokens set (#9862)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
This commit is contained in:
jthomson04 2025-12-11 13:55:31 -08:00 committed by GitHub
parent 95d928f071
commit 4f6d4da035
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2230,9 +2230,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
for beam_idx in range(num_beams[index]):
for step in range(num_steps[index]):
if r.get_num_tokens(beam_idx) + step < r.py_min_length[0]:
# NOTE(jthomson04): We can NOT just assign logits[...] = float("-inf").
# This introduces a pageable HtoD transfer, which wreaks havoc on TPOT (up to ~20%)
# Instead, we create a little tensor on device, then assign to that.
# This way, we avoid the pageable transfer.
neg_inf_tensor = torch.full((), float("-inf"), device=logits.device)
logits[
current_offset + num_steps[index] * beam_idx + step, r.py_end_id
] = float("-inf")
] = neg_inf_tensor
else:
# early exit
break