From 0f7ec033f74e08ec94d3ae0af917e6364b9c9b85 Mon Sep 17 00:00:00 2001 From: mpikulski <206748156+ixlmar@users.noreply.github.com> Date: Tue, 20 Jan 2026 05:27:01 +0100 Subject: [PATCH] [https://nvbugs/5791242][fix] workaround for flashinfer.sampling.sampling_from_logits (#10713) Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- .../pyexecutor/sampling_utils_flashinfer.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py index 4114c2310f..75d5edb40e 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py @@ -619,20 +619,29 @@ class _StrategyImpls: generator: Optional[torch.Generator] = None, group_metadata: StrategyMetadata | None = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - logits = self._prepare_logits_with_temperature( - logits, group_logit_indices, self._temperature - ) - new_tokens = flashinfer.sampling.sampling_from_logits( + new_tokens, _ = self._sample_with_probs( logits, - # NB: Leveraging 'indices' would require applying temperature+softmax before batching, - # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would - # compute unnecessarily softmax also for situations allowing - # flashinfer.sampling...._sampling_from_logits. - # indices=group_logit_indices, - deterministic=True, + group_logit_indices=group_logit_indices, + top_k=None, + top_p=None, + temperature=self._temperature, generator=generator, - check_nan=self._flashinfer_check_nans(logits), ) + # FIXME: https://nvbugs/5791242 + # logits = self._prepare_logits_with_temperature( + # logits, group_logit_indices, self._temperature + # ) + # new_tokens = flashinfer.sampling.sampling_from_logits( + # logits, + # # NB: Leveraging 'indices' would require applying temperature+softmax before batching, + # # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would + # # compute unnecessarily softmax also for situations allowing + # # flashinfer.sampling...._sampling_from_logits. + # # indices=group_logit_indices, + # deterministic=True, + # generator=generator, + # check_nan=self._flashinfer_check_nans(logits), + # ) return new_tokens, None class BeamSearchSampleOnly(BeamSearchMixin, StrategyImplSampleOnly):