diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 8dffa020c2..f4908c8b79 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -12,11 +12,15 @@ from tensorrt_llm.logger import logger
 
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import BaseResourceManager
 
 if TYPE_CHECKING:
     from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 
+if IS_FLASHINFER_AVAILABLE:
+    import flashinfer
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
 
@@ -371,6 +375,9 @@ class SpecWorkerBase(nn.Module, ABC):
         super().__init__()
         self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
         self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+        self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
+        self.seed = 0
+        self.offset = 0
 
     @property
     @abstractmethod
@@ -446,8 +453,17 @@ class SpecWorkerBase(nn.Module, ABC):
             top_ks = spec_metadata.top_ks[:num_tokens]
             top_ps = spec_metadata.top_ps[:num_tokens]
 
+            if self.use_flashinfer:
+                self.seed += 1
+
             sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
+                logits,
+                temperatures,
+                top_ks,
+                top_ps,
+                use_flashinfer=self.use_flashinfer,
+                seed=self.seed,
+                offset=self.offset)
         else:
             sampled_tokens = torch.argmax(logits, dim=-1)
 
diff --git a/tensorrt_llm/_torch/speculative/one_model_sampler.py b/tensorrt_llm/_torch/speculative/one_model_sampler.py
index ca48c03f28..7d49aa85dd 100644
--- a/tensorrt_llm/_torch/speculative/one_model_sampler.py
+++ b/tensorrt_llm/_torch/speculative/one_model_sampler.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_logits
 
 
 def forward_native(
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
     temperatures: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    use_flashinfer: bool = False,
+    seed: Optional[int] = None,
+    offset: Optional[int] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@ def sampling_batch_spec_dec_one_model(
     sampling is opt-in for now.
     """
     logits = apply_temperature(logits, temperatures)
+    if use_flashinfer:
+        return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
     random_sampled = forward_native(logits, top_k, top_p)
     return random_sampled
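
For reviewers, a minimal self-contained sketch of the dispatch pattern this patch introduces: use FlashInfer's fused top-k/top-p kernel when the library is importable and new enough, otherwise fall back to a pure-PyTorch path. The names HAVE_FLASHINFER, _native_top_k_top_p, and sample are illustrative only; the torch fallback body is a stand-in for the repo's forward_native, not its actual implementation; the seed/offset keywords and the ">= 0.6.0" gate mirror the calls added above and assume that flashinfer version.

import torch

try:
    import flashinfer
    from flashinfer.sampling import top_k_top_p_sampling_from_logits

    # Mirrors the patch's version gate (string comparison, as in the diff).
    HAVE_FLASHINFER = flashinfer.__version__ >= "0.6.0"
except ImportError:
    HAVE_FLASHINFER = False


def _native_top_k_top_p(logits: torch.Tensor, top_k: torch.Tensor,
                        top_p: torch.Tensor) -> torch.Tensor:
    # Illustrative torch-only fallback; assumes every top_k entry is >= 1.
    # Keep only logits at or above each row's top_k-th largest value.
    kth = torch.topk(logits, int(top_k.max().item()), dim=-1).values
    thresholds = kth.gather(-1, (top_k - 1).clamp(min=0).long().unsqueeze(-1))
    logits = logits.masked_fill(logits < thresholds, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Zero out the tail whose cumulative probability (excluding self) exceeds top_p.
    tail = sorted_probs.cumsum(dim=-1) - sorted_probs > top_p.unsqueeze(-1)
    probs = torch.zeros_like(probs).scatter_(
        -1, sorted_idx, sorted_probs.masked_fill(tail, 0.0))
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


def sample(logits: torch.Tensor, top_k: torch.Tensor, top_p: torch.Tensor,
           seed: int = 0, offset: int = 0) -> torch.Tensor:
    if HAVE_FLASHINFER:
        # Explicit (seed, offset) mirrors the patch's call and its
        # CUDA-graph compatible sampling goal.
        return top_k_top_p_sampling_from_logits(
            logits, top_k, top_p, seed=seed, offset=offset)
    return _native_top_k_top_p(logits, top_k, top_p)

Note that in the patch itself, SpecWorkerBase increments self.seed before each FlashInfer sampling call, so repeated invocations are not tied to a single fixed seed.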