mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 02:31:33 +08:00
[https://nvbugs/5791242][fix] workaround for flashinfer.sampling.sampling_from_logits (#10713)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent
8959c41d8b
commit
0f7ec033f7
@ -619,20 +619,29 @@ class _StrategyImpls:
|
||||
generator: Optional[torch.Generator] = None,
|
||||
group_metadata: StrategyMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self._prepare_logits_with_temperature(
|
||||
logits, group_logit_indices, self._temperature
|
||||
)
|
||||
new_tokens = flashinfer.sampling.sampling_from_logits(
|
||||
new_tokens, _ = self._sample_with_probs(
|
||||
logits,
|
||||
# NB: Leveraging 'indices' would require applying temperature+softmax before batching,
|
||||
# because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
|
||||
# compute unnecessarily softmax also for situations allowing
|
||||
# flashinfer.sampling...._sampling_from_logits.
|
||||
# indices=group_logit_indices,
|
||||
deterministic=True,
|
||||
group_logit_indices=group_logit_indices,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
temperature=self._temperature,
|
||||
generator=generator,
|
||||
check_nan=self._flashinfer_check_nans(logits),
|
||||
)
|
||||
# FIXME: https://nvbugs/5791242
|
||||
# logits = self._prepare_logits_with_temperature(
|
||||
# logits, group_logit_indices, self._temperature
|
||||
# )
|
||||
# new_tokens = flashinfer.sampling.sampling_from_logits(
|
||||
# logits,
|
||||
# # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
|
||||
# # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
|
||||
# # compute unnecessarily softmax also for situations allowing
|
||||
# # flashinfer.sampling...._sampling_from_logits.
|
||||
# # indices=group_logit_indices,
|
||||
# deterministic=True,
|
||||
# generator=generator,
|
||||
# check_nan=self._flashinfer_check_nans(logits),
|
||||
# )
|
||||
return new_tokens, None
|
||||
|
||||
class BeamSearchSampleOnly(BeamSearchMixin, StrategyImplSampleOnly):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user