mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
175 lines · 6.0 KiB · Python
import copy
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import Dict, List, Optional

import torch

from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..model_config import TConfig
from ..pyexecutor.scheduler import ScheduledRequests


class SpeculativeDecodingMode(IntEnum):
    MTP = auto()
    MTP_EAGLE = auto()
    EAGLE3 = auto()
    NGRAM = auto()
    NONE = auto()

    def is_mtp(self):
        return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE

    def is_mtp_eagle(self):
        return self == SpeculativeDecodingMode.MTP_EAGLE

    def is_eagle3(self):
        return self == SpeculativeDecodingMode.EAGLE3

    def is_ngram(self):
        return self == SpeculativeDecodingMode.NGRAM

    def is_none(self):
        return self == SpeculativeDecodingMode.NONE

    def needs_kv_cache_rewind(self):
        return self.is_mtp()

    def support_overlap_scheduler(self):
        return self.is_mtp()

    def has_spec_decoder(self):
        return self.is_mtp() or self.is_eagle3()

    def extend_ctx(self, attention_backend: AttentionBackend):
        """
        If true, treat generation requests with draft tokens as
        chunked context requests at the kernel level. Required for
        any spec dec mode that uses the SpecExecutor.
        """
        # FIXME: only the TRT-LLM attention backend supports EAGLE3
        # generation-phase kernels on Blackwell (SM 100).
        return (self.is_eagle3()
                and not (isinstance(attention_backend, TrtllmAttention)
                         and get_sm_version() == 100)) or self.is_ngram()
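
    # Hedged usage sketch (illustrative, not part of the original file): the
    # executor could use extend_ctx to decide whether generation requests that
    # carry draft tokens should take the chunked-context kernel path.
    # "attn_backend" is an assumed AttentionBackend instance.
    #
    #     if spec_dec_mode.extend_ctx(attn_backend):
    #         ...  # schedule gen requests with drafts as chunked context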

    @staticmethod
    def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
        if name is None:
            return SpeculativeDecodingMode.NONE
        return SpeculativeDecodingMode[name.upper()]
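
# Hedged usage sketch (illustrative, not part of the original file): resolving
# a mode from a config string and branching on its capabilities; "eagle3" is
# just an example value.
#
#     mode = SpeculativeDecodingMode.from_string("eagle3")
#     assert mode.is_eagle3() and mode.has_spec_decoder()
#     if mode.needs_kv_cache_rewind():
#         ...  # rewind KV cache entries for rejected draft tokens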


@dataclass
class SpecConfig:
    """
    Configuration for speculative decoding.
    """
    # The name of the speculative decoding method.
    spec_dec_name = None
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # The max number of draft tokens.
    max_draft_tokens: int = 1024

    def __post_init__(self) -> None:
        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
            self.spec_dec_name)

    def update_from_model_config(self, model_config: TConfig):
        pass

    def get_draft_model_prompt(self,
                               input_tokens: torch.Tensor) -> torch.Tensor:
        """
        Override for spec dec modes that need to preprocess prompt
        tokens before passing them to the draft model.
        """
        return input_tokens
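
# Minimal usage sketch (illustrative, not part of the original file): concrete
# configs are expected to set spec_dec_name, which __post_init__ resolves to
# the matching mode. "MyNGramConfig" and its values are hypothetical.
#
#     @dataclass
#     class MyNGramConfig(SpecConfig):
#         spec_dec_name = "NGRAM"
#
#     cfg = MyNGramConfig()
#     assert cfg.spec_dec_mode.is_ngram()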


@dataclass
class SpecMetadata:
    """
    Metadata for speculative decoding.
    """
    # The max number of requests in a single batch.
    max_num_requests: int
    # The max number of draft tokens.
    max_draft_tokens: int
    # The number of gen-phase sequences in the batch.
    num_generations: int = 0
    # Whether CUDA graph is enabled.
    is_cuda_graph: bool = field(default=False, repr=False)
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # Draft tokens.
    draft_tokens: Optional[torch.Tensor] = None
    # The length of the draft tokens.
    draft_lens: Optional[torch.Tensor] = None
    # The request ID of each sequence in the batch.
    # The shape is (batch_size).
    request_ids: Optional[List[int]] = None
    # Sequence length for each request.
    seq_lens: Optional[List[int]] = None
    # The gather ids for logits.
    gather_ids: Optional[torch.Tensor] = None
    # The number of tokens for the speculative model/layer.
    num_tokens: int = 0
    # The number of tokens for the speculative model/layer on each rank.
    all_rank_num_tokens: Optional[List[int]] = None
    # The number of sequences for the speculative model/layer on each rank.
    all_rank_num_seqs: Optional[List[int]] = None
    # The number of extra kv tokens.
    # Some speculative decoding methods need to use different kv lengths for the
    # draft/target layers. But KVCacheManager can only support kv caches with the
    # same kv lengths for different layers. Add extra kv tokens in the kv cache
    # manager to handle this issue.
    num_extra_kv_tokens: Optional[int] = 0
    # The number of layers in the target model.
    num_layers: int = 0

    def prepare(self):
        """
        Hook to be called before the forward step of the model.
        """

    def create_cuda_graph_metadata(self, max_batch_size: int):
        """
        Creates metadata for CUDA graph execution.
        """
        if self.is_cuda_graph:
            return self

        cuda_graph_metadata = copy.copy(self)
        cuda_graph_metadata.is_cuda_graph = True
        cuda_graph_metadata.max_num_requests = max_batch_size
        # Re-run __post_init__ (defined by subclasses) so any derived
        # buffers are rebuilt for the padded CUDA-graph batch size.
        cuda_graph_metadata.__post_init__()
        return cuda_graph_metadata
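
    # Hedged usage sketch (illustrative, not part of the original file): a CUDA
    # graph runner might capture once with padded metadata and reuse it.
    # "max_batch_size=8" is an arbitrary example value.
    #
    #     graph_metadata = spec_metadata.create_cuda_graph_metadata(
    #         max_batch_size=8)
    #     assert graph_metadata.is_cuda_graph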

    def maybe_capture_hidden_states(self, layer_id: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        """
        Some spec decode algorithms require hidden states from the target
        model. Use this method to record them. By default, does nothing.
        """

    def get_hidden_states(
            self,
            scheduled_requests: ScheduledRequests,
            num_rejected_tokens: Optional[Dict] = None) -> List[torch.Tensor]:
        """
        Return any captured hidden states. Should do any necessary
        pre-processing.

        num_rejected_tokens is a dictionary mapping request IDs to the
        number of tokens rejected for that request. If a request ID isn't
        in the dictionary, it means that the request is not needed for drafting.

        If the dictionary is not given, this function assumes that the hidden
        states are being prepared for running the draft model autoregressively,
        and only the last hidden state vector for each sequence is returned.
        """
        return []
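
# Hedged usage sketch (illustrative, not part of the original file): a drafter
# might fetch hidden states after target-model verification. The request IDs
# and rejection counts below are made up.
#
#     # Request 0 had 2 draft tokens rejected, request 1 had none; requests
#     # absent from the dict are skipped for drafting.
#     hidden = spec_metadata.get_hidden_states(
#         scheduled_requests, num_rejected_tokens={0: 2, 1: 0})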