TensorRT-LLM/tensorrt_llm/_torch/speculative/interface.py

import copy
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional, Type
import torch
from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
class SpeculativeDecodingMode(IntEnum):
MTP = auto()
MTP_EAGLE = auto()
EAGLE3 = auto()
EAGLE3_ONE_MODEL = auto()
NGRAM = auto()
DRAFT_TARGET = auto()
USER_PROVIDED = auto()
NONE = auto()
AUTO = auto()
def is_mtp(self):
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
def is_mtp_eagle(self):
return self == SpeculativeDecodingMode.MTP_EAGLE
def is_eagle3(self):
return self == SpeculativeDecodingMode.EAGLE3
def use_one_engine(self):
return self.is_mtp() or self.is_eagle3_one_model()
def is_eagle3_one_model(self):
return self == SpeculativeDecodingMode.EAGLE3_ONE_MODEL
def is_ngram(self):
return self == SpeculativeDecodingMode.NGRAM
def is_user_provided(self):
return self == SpeculativeDecodingMode.USER_PROVIDED
def is_none(self):
return self == SpeculativeDecodingMode.NONE
def is_draft_target(self):
return self == SpeculativeDecodingMode.DRAFT_TARGET
def without_logits(self):
return self.is_mtp() or self.is_eagle3_one_model()
def needs_kv_cache_rewind(self):
return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram()
def support_overlap_scheduler(self):
return self.is_mtp() or self.is_eagle3_one_model()
def has_draft_model(self):
return self.is_eagle3() or self.is_draft_target()
def needs_kv_cache_recompute(self):
"""
        Whether the draft model needs to recompute the KV cache.
        If true, the first draft-model forward pass recomputes the KV cache
        for the accepted draft tokens.
"""
return self.is_eagle3()
def need_load_draft_weights(self):
"""
        Whether the draft model and target model share the same model engine,
        and the draft model needs to load its weights from a separate checkpoint.
"""
return self.is_eagle3_one_model()
def has_spec_decoder(self):
return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model()
def has_spec_drafter(self):
        return (self.is_eagle3() or self.is_draft_target() or self.is_ngram()
                or self.is_user_provided())
def extend_ctx(self, attention_backend: Type[AttentionBackend]):
"""
If true, treat generation requests with draft tokens as
chunked context requests at the kernel level. Required for
any spec dec mode that uses the SpecExecutor.
"""
        # FIXME: only the TRT-LLM attention backend supports eagle3
        # generation-phase kernels on Blackwell (SM 100).
        return ((self.is_eagle3() or self.is_draft_target())
                and not (issubclass(attention_backend, TrtllmAttention)
                         and get_sm_version() == 100)
                ) or self.is_ngram() or self.is_user_provided()
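
    # Illustrative decision table for extend_ctx (a sketch; get_sm_version()
    # is read at runtime, so the values below assume a non-Blackwell GPU):
    #
    #     SpeculativeDecodingMode.NGRAM.extend_ctx(TrtllmAttention)   # True
    #     SpeculativeDecodingMode.EAGLE3.extend_ctx(TrtllmAttention)  # True (False on SM 100)
    #     SpeculativeDecodingMode.MTP.extend_ctx(TrtllmAttention)     # False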
def attention_need_spec_dec_mode(self):
"""
If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
"""
return self.is_eagle3_one_model()
@staticmethod
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
if name is None:
return SpeculativeDecodingMode.NONE
return SpeculativeDecodingMode[name.upper()]
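
# Illustrative usage of the mode predicates (a minimal sketch, not part of the
# module's API). The config string below is a hypothetical input:
#
#     mode = SpeculativeDecodingMode.from_string("ngram")
#     assert mode.is_ngram()
#     assert mode.has_spec_drafter()       # ngram runs a separate drafter
#     assert mode.needs_kv_cache_rewind()  # rejected draft tokens are rewound
#     assert not mode.use_one_engine()     # draft/target are not fused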
@dataclass
class SpecMetadata:
"""
Metadata for speculative decoding.
"""
# The max number of requests in a single batch.
max_num_requests: int
# The max number of draft tokens.
max_draft_len: int
# The number of gen-phase sequences in the batch.
num_generations: int = 0
# Whether CUDA graph is enabled.
is_cuda_graph: bool = field(default=False, repr=False)
# The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # Draft tokens.
    draft_tokens: Optional[torch.Tensor] = None
    # The length of the draft tokens.
    draft_lens: Optional[torch.Tensor] = None
# The request ID of each sequence in the batch.
# The shape is (batch_size).
request_ids: Optional[List[int]] = None
# Sequence length for each request.
seq_lens: Optional[List[int]] = None
# The gather ids for logits.
gather_ids: Optional[torch.Tensor] = None
    # The number of tokens for the speculative model/layer.
    num_tokens: int = 0
    # The number of tokens for the speculative model/layer on each rank.
    _all_rank_num_tokens: Optional[List[int]] = field(init=False,
                                                      default=None,
                                                      repr=False)
    # Public view of _all_rank_num_tokens. The property defined below becomes
    # this field's default; the setter treats it as a "not provided" sentinel.
    all_rank_num_tokens: Optional[List[int]]
# The max number of tokens among all ranks.
all_rank_max_num_tokens: Optional[int] = None
    # The number of sequences for the speculative model/layer on each rank.
    all_rank_num_seqs: Optional[List[int]] = None
    # The number of extra KV tokens.
    # Some speculative decoding methods need different KV lengths for the
    # draft and target layers, but KVCacheManager only supports caches with
    # the same KV length across layers. Extra KV tokens are added in the KV
    # cache manager to work around this.
    num_extra_kv_tokens: Optional[int] = 0
    # The number of layers in the target model.
    num_layers: int = 0
    # Whether the spec-dec structure is a tree rather than a chain (linear tree).
    is_spec_dec_tree: bool = False
    # Whether the spec-dec tree changes between steps. If it is static, the
    # spec-dec mask does not need to be recomputed every step.
    is_spec_dec_dynamic_tree: bool = False
def __post_init__(self):
pass
def prepare(self):
"""
Hook to be called before the forward step of the model.
"""
def create_cuda_graph_metadata(self, max_batch_size: int):
"""
Creates metadata for CUDA graph execution.
"""
if self.is_cuda_graph:
return self
cuda_graph_metadata = copy.copy(self)
cuda_graph_metadata.is_cuda_graph = True
cuda_graph_metadata.max_num_requests = max_batch_size
cuda_graph_metadata.__post_init__()
return cuda_graph_metadata
def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
"""
Some spec decode algorithms require hidden states from the target
model. Use this method to record them. By default, does nothing.
"""
@property
def all_rank_num_tokens(self) -> Optional[List[int]]:
return self._all_rank_num_tokens
    @all_rank_num_tokens.setter
    def all_rank_num_tokens(self, value: Optional[List[int]]):
        # The dataclass-generated __init__ passes the class-level property
        # object itself when the caller does not supply a value; treat that
        # sentinel as None.
        value = value if value is not SpecMetadata.all_rank_num_tokens else None
        self._all_rank_num_tokens = value
        self.all_rank_max_num_tokens = max(value) if value is not None else None
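
# Example (illustrative only; field values are made up). Shows the property
# setter keeping all_rank_max_num_tokens in sync and the CUDA-graph metadata
# copy flow:
#
#     meta = SpecMetadata(max_num_requests=8, max_draft_len=4)
#     meta.all_rank_num_tokens = [128, 130, 127, 129]
#     assert meta.all_rank_max_num_tokens == 130
#     graph_meta = meta.create_cuda_graph_metadata(max_batch_size=16)
#     assert graph_meta.is_cuda_graph and graph_meta.max_num_requests == 16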