[None][feat] Speculative One Model: FlashInfer sampling (#10284)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Izzy Putterman 2026-01-20 09:56:43 -08:00 committed by GitHub
parent 66b239a9a9
commit 864b61cadd
2 changed files with 23 additions and 1 deletion


@@ -12,11 +12,15 @@ from tensorrt_llm.logger import logger
from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..pyexecutor.resource_manager import BaseResourceManager
if TYPE_CHECKING:
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
if IS_FLASHINFER_AVAILABLE:
import flashinfer
# Environment variable name for forcing the number of accepted tokens in speculative decoding
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
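
For context, the helper get_force_num_accepted_tokens (used later in this diff) presumably just parses this environment variable. A minimal sketch, assuming it returns an optional int and treats an unset variable as no override; the real helper lives elsewhere in the TensorRT-LLM repo:

import os
from typing import Optional

FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"

def get_force_num_accepted_tokens() -> Optional[int]:
    # Hypothetical sketch: read the override from the environment,
    # returning None when the variable is unset.
    value = os.environ.get(FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR)
    return int(value) if value is not None else None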
@@ -371,6 +375,9 @@ class SpecWorkerBase(nn.Module, ABC):
super().__init__()
self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()
self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
self.seed = 0
self.offset = 0
@property
@abstractmethod
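
Note that the version gate above compares version strings lexicographically, which misorders releases such as "0.10.0" against "0.6.0". A hedged alternative sketch using the packaging library (not what this commit does):

from packaging.version import Version

# "0.10.0" >= "0.6.0" is False as a string comparison but True
# numerically; Version objects compare release components as numbers.
try:
    import flashinfer
    use_flashinfer = Version(flashinfer.__version__) >= Version("0.6.0")
except ImportError:
    use_flashinfer = False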
@@ -446,8 +453,17 @@ class SpecWorkerBase(nn.Module, ABC):
top_ks = spec_metadata.top_ks[:num_tokens]
top_ps = spec_metadata.top_ps[:num_tokens]
if self.use_flashinfer:
self.seed += 1
sampled_tokens = sampling_batch_spec_dec_one_model(
logits, temperatures, top_ks, top_ps)
logits,
temperatures,
top_ks,
top_ps,
use_flashinfer=self.use_flashinfer,
seed=self.seed,
offset=self.offset)
else:
sampled_tokens = torch.argmax(logits, dim=-1)
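
The worker bumps self.seed once per sampling call while self.offset stays fixed, so successive calls hand the FlashInfer kernel a distinct RNG state and draw different tokens. A minimal sketch of that state pattern, with the actual sampling call elided:

# Sketch of the per-call RNG state kept on SpecWorkerBase
# (names taken from the hunk above).
class _RngState:
    def __init__(self) -> None:
        self.seed = 0
        self.offset = 0

    def advance(self) -> tuple[int, int]:
        # seed advances each call; offset stays 0.
        self.seed += 1
        return self.seed, self.offset

state = _RngState()
for _ in range(3):  # e.g. three drafting steps
    seed, offset = state.advance()
    # sampled_tokens = sampling_batch_spec_dec_one_model(
    #     ..., seed=seed, offset=offset)
print(state.seed)  # 3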


@@ -1,6 +1,7 @@
from typing import Optional
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_logits
def forward_native(
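
Unlike the first file, this module imports flashinfer.sampling unconditionally at module scope, so importing it fails when flashinfer is absent. A hedged sketch of a guarded variant mirroring the first file's pattern (the relative import path is an assumption):

# Guarded variant (a sketch; the commit itself imports unconditionally).
# IS_FLASHINFER_AVAILABLE comes from TensorRT-LLM's flashinfer_utils.
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE

if IS_FLASHINFER_AVAILABLE:
    from flashinfer.sampling import top_k_top_p_sampling_from_logits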
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
temperatures: torch.Tensor,
top_k: torch.Tensor,
top_p: torch.Tensor,
use_flashinfer: bool = False,
seed: Optional[int] = None,
offset: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@
sampling is opt-in for now.
"""
logits = apply_temperature(logits, temperatures)
if use_flashinfer:
return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
random_sampled = forward_native(logits, top_k, top_p)
return random_sampled
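
A hedged usage sketch of the updated signature, using dummy tensors and the pure-PyTorch fallback path; the repo-internal import path of sampling_batch_spec_dec_one_model is elided here:

import torch

# from <repo-internal module in this diff> import sampling_batch_spec_dec_one_model

num_tokens, vocab_size = 4, 32000
logits = torch.randn(num_tokens, vocab_size)
temperatures = torch.full((num_tokens,), 0.8)
top_k = torch.full((num_tokens,), 50, dtype=torch.int32)
top_p = torch.full((num_tokens,), 0.9)

# use_flashinfer=False exercises the forward_native fallback;
# seed and offset are only consumed on the FlashInfer path.
tokens = sampling_batch_spec_dec_one_model(
    logits, temperatures, top_k, top_p, use_flashinfer=False)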