From 4f6d4da035817a144a8539f684f10b554417692b Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 11 Dec 2025 13:55:31 -0800 Subject: [PATCH] [None][perf] Fix TPOT when `min_tokens` set (#9862) Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/sampler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 40d1450e45..83826eaad7 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -2230,9 +2230,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin): for beam_idx in range(num_beams[index]): for step in range(num_steps[index]): if r.get_num_tokens(beam_idx) + step < r.py_min_length[0]: + # NOTE(jthomson04): We can NOT just assign logits[...] = float("-inf"). + # This introduces a pageable HtoD transfer, which wreaks havoc on TPOT (up to ~20%). + # Instead, we create a little tensor on device, then assign from it into logits. + # This way, we avoid the pageable transfer. + neg_inf_tensor = torch.full((), float("-inf"), device=logits.device) logits[ current_offset + num_steps[index] * beam_idx + step, r.py_end_id - ] = float("-inf") + ] = neg_inf_tensor else: # early exit break