import copy import os from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import IntEnum, auto from typing import TYPE_CHECKING, List, Optional, Type import torch from torch import nn from tensorrt_llm.logger import logger from ..._utils import get_sm_version from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention from ..pyexecutor.resource_manager import BaseResourceManager if TYPE_CHECKING: from ..pyexecutor.guided_decoder import CapturableGuidedDecoder # Environment variable name for forcing the number of accepted tokens in speculative decoding FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS" def get_force_num_accepted_tokens() -> int: """ Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable. Returns: int: The forced number of accepted tokens, or 0 if not set or invalid. """ env_value = os.environ.get(FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR, "0") try: return int(env_value) except ValueError: logger.warning( f"{FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR} must be a valid integer, " f"got '{env_value}'. Using default value 0.") return 0 class SpeculativeDecodingMode(IntEnum): MTP = auto() MTP_EAGLE = auto() MTP_EAGLE_ONE_MODEL = auto() EAGLE3 = auto() EAGLE3_ONE_MODEL = auto() NGRAM = auto() DRAFT_TARGET = auto() USER_PROVIDED = auto() SAVE_HIDDEN_STATES = auto() NONE = auto() AUTO = auto() def is_mtp_one_model(self): return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL def is_mtp_eagle_one_model(self): return self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL def is_mtp_vanilla(self): return self == SpeculativeDecodingMode.MTP def is_mtp_eagle(self): return self == SpeculativeDecodingMode.MTP_EAGLE def is_eagle3(self): return self == SpeculativeDecodingMode.EAGLE3 def use_one_engine(self): return self.is_eagle3_one_model() or self.is_mtp_one_model() def is_eagle3_one_model(self): return self == SpeculativeDecodingMode.EAGLE3_ONE_MODEL def is_ngram(self): return self == SpeculativeDecodingMode.NGRAM def is_user_provided(self): return self == SpeculativeDecodingMode.USER_PROVIDED def is_none(self): return self == SpeculativeDecodingMode.NONE def is_draft_target(self): return self == SpeculativeDecodingMode.DRAFT_TARGET def is_save_hidden_states(self): return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES def without_logits(self): return self.is_mtp_one_model() or self.is_eagle3_one_model() def needs_kv_cache_rewind(self): return self.is_mtp_one_model() or self.is_eagle3_one_model( ) or self.is_ngram() def support_overlap_scheduler(self): return self.is_mtp_one_model() or self.is_eagle3_one_model( ) or self.has_draft_model() def support_guided_decoder(self): return self.is_none() or self.has_spec_drafter() def support_capturable_guided_decoder(self): return self.is_mtp_one_model() or self.is_eagle3_one_model() def has_draft_model(self): return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle() def needs_kv_cache_recompute(self): """ Whether the draft model needs to recompute the kv cache. If true, the 1st draft model forward will recompute the kv cache for the accepted draft tokens. """ return self.is_eagle3() or self.is_mtp_eagle() def need_load_draft_weights(self): """ Whether the draft model and target model are in the same model engine, and the draft model needs to load weights from the separate checkpoint. """ return self.is_eagle3_one_model() def has_spec_decoder(self): return self.is_mtp_one_model() or self.is_mtp_eagle() or self.is_eagle3( ) or self.is_eagle3_one_model() def has_spec_drafter(self): return self.is_eagle3( ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided( ) or self.is_mtp_eagle() or self.is_save_hidden_states() def extend_ctx(self, attention_backend: Type[AttentionBackend]): """ If true, treat generation requests with draft tokens as chunked context requests at the kernel level. """ if self.use_one_engine(): # 1-model has separate logic for handling draft tokens return False return not issubclass(attention_backend, TrtllmAttention) or get_sm_version() < 90 def attention_need_spec_dec_mode( self, spec_resource_manager: Optional[BaseResourceManager], is_draft_model: bool, attention_backend: Type[AttentionBackend], use_chain_drafter: bool, # CDL ): """ If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode). Args: spec_resource_manager: the resource manager for the spec-dec mode. is_draft_model: whether the model is a draft model. attention_backend: the attention backend. use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False. """ is_trtllm_attention = issubclass(attention_backend, TrtllmAttention) # Always use the multi-token query mode for 1-model. # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with # the target model (verification) or on the first draft for CDL based speculation. use_case_1 = self.is_eagle3_one_model() use_case_2 = (not is_draft_model or (spec_resource_manager is not None and spec_resource_manager.is_first_draft and use_chain_drafter)) and is_trtllm_attention return use_case_1 or use_case_2 @staticmethod def from_string(name: Optional[str]) -> "SpeculativeDecodingMode": if name is None: return SpeculativeDecodingMode.NONE return SpeculativeDecodingMode[name.upper()] @dataclass class SpecMetadata: """ Metadata for speculative decoding. """ # The max number of requests in a single batch. max_num_requests: int # The number of draft layers. (Also the number of draft tokens for the linear tree.) max_draft_len: int # The max number of draft tokens for the static tree and dynamic tree . max_total_draft_tokens: int # The number of gen-phase sequences in the batch. num_generations: int = 0 # Whether CUDA graph is enabled. is_cuda_graph: bool = field(default=False, repr=False) # The mode of speculative decoding. spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE # Draft tokens. draft_tokens: Optional[torch.Tensor] = None # The length of the draft tokens. draft_lens: Optional[torch.Tensor] = None # The request ID of each sequence in the batch. # The shape is (batch_size). request_ids: Optional[List[int]] = None # Sequence length for each request. seq_lens: Optional[List[int]] = None # The gather ids for logits. gather_ids: Optional[torch.Tensor] = None # The number of accepted draft tokens for each request. num_accepted_draft_tokens: Optional[torch.Tensor] = None # The number of tokens for speculative model/layer num_tokens: int = 0 # The number of tokens for speculative model/layer of different rank all_rank_num_tokens: Optional[List[int]] = None # The number of sequences for speculative model/layer of different rank all_rank_num_seqs: Optional[List[int]] = None # The number of extra kv tokens # Some speculative decoding methods need to use different kv lengths for the # draft/target layers. But KVCacheManager can only support kv caches with the # same kv lengths for different layers. Add extra kv token in kv cache manager # to handle this issue. num_extra_kv_tokens: Optional[int] = 0 # Number of layers in target model # The number of layers num_layers: int = 0 # if spec-dec tree wouldn't be changed at all, the mask won't be computed every step. # NOTE: For the linear tree, though it can be treated as a special case of static tree. # NOTE: But we do not set `is_spec_dec_tree` to True for this cases. # NOTE: i.e., for the linear tree, is_spec_dec_tree == False and is_spec_dec_dynamic_tree == False. # whether the spec-dec mode is a tree (can be static tree or dynamic tree). is_spec_dec_tree: bool = False # whether the spec-dec mode is a dynamic tree. is_spec_dec_dynamic_tree: bool = False # For non-greedy sampling on 1-model. allow_advanced_sampling: bool = False # Sampling parameters for non-greedy sampling (per-request) temperatures: Optional[torch.Tensor] = None top_ks: Optional[torch.Tensor] = None top_ps: Optional[torch.Tensor] = None def __post_init__(self): pass def prepare(self): """ Hook to be called before the forward step of the model. """ def create_cuda_graph_metadata(self, max_batch_size: int): """ Creates metadata for CUDA graph execution. """ if self.is_cuda_graph: return self cuda_graph_metadata = copy.copy(self) cuda_graph_metadata.is_cuda_graph = True cuda_graph_metadata.max_num_requests = max_batch_size cuda_graph_metadata.__post_init__() return cuda_graph_metadata def is_layer_capture(self, layer_id: int): """ Whether the layer should be captured (eg for Eagle3). By default, does nothing. """ return False def maybe_capture_hidden_states(self, layer_id: int, hidden_states: torch.Tensor, residual: torch.Tensor) -> None: """ Some spec decode algorithms require hidden states from the target model. Use this method to record them. By default, does nothing. """ def populate_sampling_params_for_one_model( self, requests: list["LlmRequest"]) -> None: """ Set up topp/topk/temperatures for 1-model sampler. """ from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState from tensorrt_llm.sampling_params import SamplingParams if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine( ): return if self.temperatures is None: # Ensures determinism across ranks. torch.manual_seed(0) temperatures = [] top_ks = [] top_ps = [] # Need to use a very small value for temperature when disabled to avoid division by 0 DISABLE_TEMP_VAL = 1e-5 # Very large values disable topk. DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max DISABLE_TOPP_VAL = 1.0 for request in requests: sampling_config = request.sampling_config temp = sampling_config.temperature temp_val = temp[0] if temp is not None and len(temp) > 0 else None tk = sampling_config.top_k tk_val = tk[0] if tk is not None and len(tk) > 0 else None tp = sampling_config.top_p tp_val = tp[0] if tp is not None and len(tp) > 0 else None # Context requests have no draft tokens yet. num_tokens = 1 + self.max_draft_len if request.state == LlmRequestState.GENERATION_IN_PROGRESS else 1 is_greedy = SamplingParams.params_imply_greedy_decoding( temperature=temp_val, top_k=tk_val, top_p=tp_val, use_beam_search=False) temp_val = DISABLE_TEMP_VAL if is_greedy or temp_val is None or temp_val == 0 else temp_val tk_val = DISABLE_TOPK_VAL if is_greedy or tk_val is None or tk_val <= 0 else tk_val tp_val = DISABLE_TOPP_VAL if is_greedy or tp_val is None else tp_val temperatures.extend(temp_val for _ in range(num_tokens)) top_ks.extend(tk_val for _ in range(num_tokens)) top_ps.extend(tp_val for _ in range(num_tokens)) if self.temperatures is None: self.temperatures = torch.ones( (self.max_draft_len + 1) * self.max_num_requests, dtype=torch.float32, device='cuda') self.top_ks = torch.zeros( (self.max_draft_len + 1) * self.max_num_requests, dtype=torch.int32, device='cuda') self.top_ps = torch.ones( (self.max_draft_len + 1) * self.max_num_requests, dtype=torch.float32, device='cuda') self.temperatures[:len(temperatures)].copy_(torch.tensor( temperatures, dtype=torch.float32, pin_memory=True), non_blocking=True) self.top_ks[:len(top_ks)].copy_(torch.tensor(top_ks, dtype=torch.int32, pin_memory=True), non_blocking=True) self.top_ps[:len(top_ps)].copy_(torch.tensor(top_ps, dtype=torch.float32, pin_memory=True), non_blocking=True) class SpecWorkerBase(nn.Module, ABC): """ Base class for speculative decoding workers. Provides common functionality for sampling and token handling. """ def __init__(self): super().__init__() self.guided_decoder: Optional["CapturableGuidedDecoder"] = None self.force_num_accepted_tokens = get_force_num_accepted_tokens() @property @abstractmethod def max_draft_len(self) -> int: """ Returns the maximum draft length for this worker. Subclasses should override this property. """ def set_guided_decoder(self, guided_decoder: "CapturableGuidedDecoder") -> bool: self.guided_decoder = guided_decoder return True def _sample_tokens_for_batch( self, logits: torch.Tensor, spec_metadata: SpecMetadata, num_contexts: int, batch_size: int, ) -> torch.Tensor: """ Sample tokens from logits using per-request sampling parameters. Supports both greedy and non-greedy sampling. Args: logits: [num_tokens, vocab_size] - Logits to sample from spec_metadata: Metadata containing sampling parameters num_contexts: Number of context requests in the batch batch_size: Number of requests in the batch Returns: sampled_tokens: [num_tokens] - Sampled token ids """ if spec_metadata.allow_advanced_sampling: from .one_model_sampler import sampling_batch_spec_dec_one_model num_gens = batch_size - num_contexts num_tokens = num_contexts + num_gens * (self.max_draft_len + 1) temperatures = spec_metadata.temperatures[:num_tokens] top_ks = spec_metadata.top_ks[:num_tokens] top_ps = spec_metadata.top_ps[:num_tokens] sampled_tokens = sampling_batch_spec_dec_one_model( logits, temperatures, top_ks, top_ps) else: sampled_tokens = torch.argmax(logits, dim=-1) return sampled_tokens