diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
index 77cdcf0a38..be58227d9a 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
@@ -611,6 +611,9 @@ class _StrategyImpls:
             generator: Optional[torch.Generator] = None,
             group_metadata: StrategyMetadata | None = None,
         ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            # NB: Gumbel-max trick sampling used by flashinfer.sampling.sampling_from_logits
+            # is numerically tricky and was not observed to provide a performance advantage
+            # (cf. https://nvbugs/5791242).
             new_tokens, _ = self._sample_with_probs(
                 logits,
                 group_logit_indices=group_logit_indices,
@@ -619,21 +622,6 @@ class _StrategyImpls:
                 temperature=self._temperature,
                 generator=generator,
             )
-            # FIXME: https://nvbugs/5791242
-            # logits = self._prepare_logits_with_temperature(
-            #     logits, group_logit_indices, self._temperature
-            # )
-            # new_tokens = flashinfer.sampling.sampling_from_logits(
-            #     logits,
-            #     # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
-            #     # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
-            #     # compute unnecessarily softmax also for situations allowing
-            #     # flashinfer.sampling...._sampling_from_logits.
-            #     # indices=group_logit_indices,
-            #     deterministic=True,
-            #     generator=generator,
-            #     check_nan=self._flashinfer_check_nans(logits),
-            # )
             return new_tokens, None

     class BeamSearchSampleOnly(BeamSearchMixin, StrategyImplSampleOnly):
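
For context on the NB comment introduced above: the Gumbel-max trick replaces softmax-then-multinomial sampling with an argmax over logits perturbed by Gumbel noise. Below is a minimal sketch of the idea in plain PyTorch; the function name `gumbel_max_sample` and the clamping scheme are illustrative assumptions, not the flashinfer kernel. The chained double logarithm is where the numerical trickiness mentioned in the comment comes from.

```python
from typing import Optional

import torch


def gumbel_max_sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Illustrative sketch: sample token ids via the Gumbel-max trick.

    In exact arithmetic, argmax(logits / T + g) with g ~ Gumbel(0, 1) is
    distributed identically to multinomial(softmax(logits / T)). Assumes
    temperature > 0.
    """
    u = torch.rand(
        logits.shape, generator=generator, device=logits.device, dtype=logits.dtype
    )
    # torch.rand draws from [0, 1); clamp away from 0 and 1 so that
    # -log(-log(u)) stays finite. How tight this clamp can be in low-precision
    # dtypes is exactly the kind of numerical subtlety the NB comment alludes to.
    eps = torch.finfo(logits.dtype).eps
    gumbel = -torch.log(-torch.log(u.clamp(min=eps, max=1.0 - eps)))
    return torch.argmax(logits / temperature + gumbel, dim=-1)
```

In floating point, the clamp bounds and the two chained logs can skew the tails of the sampled distribution, which is one reason a probs-based path (temperature + softmax, then sampling from probabilities, as `_sample_with_probs` does here) can be preferable when the Gumbel path shows no measurable speedup.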