Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-06 19:21:52 +08:00
[None][feat] Speculative One Model: FlashInfer sampling (#10284)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
parent 66b239a9a9
commit 864b61cadd
@@ -12,11 +12,15 @@ from tensorrt_llm.logger import logger
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import BaseResourceManager
 
 if TYPE_CHECKING:
     from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 
+if IS_FLASHINFER_AVAILABLE:
+    import flashinfer
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
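
For orientation, a hypothetical sketch of how a helper such as get_force_num_accepted_tokens() could consume this environment variable; the actual implementation is not shown in this diff and may differ.

import os
from typing import Optional

FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"

def get_force_num_accepted_tokens() -> Optional[int]:
    # Hypothetical sketch, not the TensorRT-LLM implementation: return the
    # forced acceptance count if the variable is set, otherwise None.
    value = os.environ.get(FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR)
    return int(value) if value else None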
@@ -371,6 +375,9 @@ class SpecWorkerBase(nn.Module, ABC):
         super().__init__()
         self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
         self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+        self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
+        self.seed = 0
+        self.offset = 0
 
     @property
     @abstractmethod
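
A side note on the version gate above: flashinfer.__version__ >= "0.6.0" is a lexicographic string comparison, so a hypothetical "0.10.0" would compare as lower than "0.6.0". A small sketch of a release-number comparison, assuming the packaging library and an installed flashinfer:

from packaging.version import Version
import flashinfer

# Sketch only: compares parsed release numbers instead of raw strings.
use_flashinfer = Version(flashinfer.__version__) >= Version("0.6.0")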
@@ -446,8 +453,17 @@ class SpecWorkerBase(nn.Module, ABC):
             top_ks = spec_metadata.top_ks[:num_tokens]
             top_ps = spec_metadata.top_ps[:num_tokens]
 
+            if self.use_flashinfer:
+                self.seed += 1
+
             sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
+                logits,
+                temperatures,
+                top_ks,
+                top_ps,
+                use_flashinfer=self.use_flashinfer,
+                seed=self.seed,
+                offset=self.offset)
         else:
             sampled_tokens = torch.argmax(logits, dim=-1)
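
The per-step self.seed += 1 above feeds a fresh seed into the sampler on every call. As a generic illustration of that pattern in plain PyTorch (not TensorRT-LLM or CUDA-graph code):

import torch

def sample_with_counter(logits: torch.Tensor, seed: int) -> torch.Tensor:
    # A host-side counter reseeds the generator each call, so repeated launches
    # of the same sampling routine still draw fresh tokens.
    gen = torch.Generator(device=logits.device).manual_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=gen).squeeze(-1)

seed = 0
logits = torch.randn(4, 32000)
for _ in range(3):
    seed += 1  # mirrors the seed bump in the hunk above
    tokens = sample_with_counter(logits, seed)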
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_logits
 
 
 def forward_native(
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
     temperatures: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    use_flashinfer: bool = False,
+    seed: Optional[int] = None,
+    offset: Optional[int] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@ def sampling_batch_spec_dec_one_model(
     sampling is opt-in for now.
     """
     logits = apply_temperature(logits, temperatures)
+    if use_flashinfer:
+        return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
     random_sampled = forward_native(logits, top_k, top_p)
     return random_sampled
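
For the use_flashinfer=False path, forward_native keeps doing the sampling in plain PyTorch. A generic, self-contained sketch of that kind of top-k/top-p fallback (scalar top_k/top_p for brevity; this is not the repository's forward_native):

import torch

def topk_topp_sample(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # Keep only the top_k highest logits per row.
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # Keep the smallest prefix of sorted probabilities whose cumulative mass
    # stays within top_p, then sample from the renormalized nucleus.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    outside_nucleus = cumulative - sorted_probs > top_p
    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)

tokens = topk_topp_sample(torch.randn(2, 32000), top_k=50, top_p=0.9)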