diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 40d1450e45..83826eaad7 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -2230,9 +2230,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin): for beam_idx in range(num_beams[index]): for step in range(num_steps[index]): if r.get_num_tokens(beam_idx) + step < r.py_min_length[0]: + # NOTE(jthomson04): We can NOT just assign logits[...] = float("-inf"). + # This introduces a pageable HtoD transfer, which wreaks havoc on TPOT (up to ~20%) + # Instead, we create a little tensor on device, then assign to that. + # This way, we avoid the pageable transfer. + neg_inf_tensor = torch.full((), float("-inf"), device=logits.device) logits[ current_offset + num_steps[index] * beam_idx + step, r.py_end_id - ] = float("-inf") + ] = neg_inf_tensor else: # early exit break