[https://nvbugs/5791242][fix] workaround for flashinfer.sampling.sampling_from_logits (#10713)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
mpikulski 2026-01-20 05:27:01 +01:00 committed by Yanchao Lu
parent 8959c41d8b
commit 0f7ec033f7

View File

@ -619,20 +619,29 @@ class _StrategyImpls:
generator: Optional[torch.Generator] = None,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self._prepare_logits_with_temperature(
logits, group_logit_indices, self._temperature
)
new_tokens = flashinfer.sampling.sampling_from_logits(
new_tokens, _ = self._sample_with_probs(
logits,
# NB: Leveraging 'indices' would require applying temperature+softmax before batching,
# because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
# compute unnecessarily softmax also for situations allowing
# flashinfer.sampling...._sampling_from_logits.
# indices=group_logit_indices,
deterministic=True,
group_logit_indices=group_logit_indices,
top_k=None,
top_p=None,
temperature=self._temperature,
generator=generator,
check_nan=self._flashinfer_check_nans(logits),
)
# FIXME: https://nvbugs/5791242
# logits = self._prepare_logits_with_temperature(
# logits, group_logit_indices, self._temperature
# )
# new_tokens = flashinfer.sampling.sampling_from_logits(
# logits,
# # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
# # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
# # compute unnecessarily softmax also for situations allowing
# # flashinfer.sampling...._sampling_from_logits.
# # indices=group_logit_indices,
# deterministic=True,
# generator=generator,
# check_nan=self._flashinfer_check_nans(logits),
# )
return new_tokens, None
class BeamSearchSampleOnly(BeamSearchMixin, StrategyImplSampleOnly):