TensorRT-LLMs/tensorrt_llm/_torch/speculative/interface.py

import copy
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional

import torch


class SpeculativeDecodingMode(IntEnum):
    MTP = auto()
    MEDUSA = auto()
    EAGLE = auto()
    LOOKAHEAD = auto()
    NONE = auto()

    def is_mtp(self):
        return self == SpeculativeDecodingMode.MTP

    def is_medusa(self):
        return self == SpeculativeDecodingMode.MEDUSA

    def is_eagle(self):
        return self == SpeculativeDecodingMode.EAGLE

    def is_lookahead(self):
        return self == SpeculativeDecodingMode.LOOKAHEAD

    def is_none(self):
        return self == SpeculativeDecodingMode.NONE

    def without_logits(self):
        return self.is_mtp()

    def needs_kv_cache_rewind(self):
        return (self.is_mtp() or self.is_eagle() or self.is_lookahead()
                or self.is_medusa())

    def support_overlap_scheduler(self):
        return self.is_mtp()

    @staticmethod
    def from_string(name: Optional[str]):
        name_map = {
            "MTP": SpeculativeDecodingMode.MTP,
            "MEDUSA": SpeculativeDecodingMode.MEDUSA,
            "EAGLE": SpeculativeDecodingMode.EAGLE,
            "LOOKAHEAD": SpeculativeDecodingMode.LOOKAHEAD,
            # A missing name maps to NONE, i.e. speculation disabled.
            None: SpeculativeDecodingMode.NONE,
        }
        return name_map[name]
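
# Illustrative usage (not part of the original file): the mode is typically
# resolved from a config string, then queried through the is_* predicates and
# the capability helpers.
#
#   mode = SpeculativeDecodingMode.from_string("EAGLE")
#   assert mode.is_eagle()
#   assert mode.needs_kv_cache_rewind()
#   assert not mode.support_overlap_scheduler()  # only MTP supports it here
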
@dataclass
class SpecConfig:
    """
    Configuration for speculative decoding.
    """
    # The name of the speculative decoding method.
    spec_dec_name: Optional[str] = None
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # The max number of draft tokens.
    max_draft_tokens: int = 1024

    def __post_init__(self) -> None:
        # Resolve the mode from the name set by the concrete config.
        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
            self.spec_dec_name)
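
# Illustrative sketch (hypothetical subclass, for exposition only): concrete
# methods are expected to subclass SpecConfig and set spec_dec_name so that
# __post_init__ resolves the matching mode.
#
#   @dataclass
#   class MTPConfig(SpecConfig):
#       spec_dec_name: str = "MTP"
#       max_draft_tokens: int = 3
#
#   assert MTPConfig().spec_dec_mode.is_mtp()
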
@dataclass
class SpecMetadata:
    """
    Metadata for speculative decoding.
    """
    # The max number of requests in a single batch.
    max_num_requests: int
    # The max number of draft tokens.
    max_draft_tokens: int
    # Whether CUDA graph is enabled.
    is_cuda_graph: bool = field(default=False, repr=False)
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # Draft tokens.
    draft_tokens: Optional[torch.Tensor] = None
    # The length of the draft tokens.
    draft_lens: Optional[torch.Tensor] = None
    # The request ID of each sequence in the batch.
    # The shape is (batch_size).
    request_ids: Optional[List[int]] = None
    # The gather ids for logits.
    gather_ids: Optional[torch.Tensor] = None

    def prepare(self):
        """
        Hook to be called before the forward step of the model.
        """

    def create_cuda_graph_metadata(self, max_batch_size: int):
        """
        Creates metadata for CUDA graph execution.
        """
        if self.is_cuda_graph:
            return self

        cuda_graph_metadata = copy.copy(self)
        cuda_graph_metadata.is_cuda_graph = True
        cuda_graph_metadata.max_num_requests = max_batch_size
        # Subclasses may define __post_init__ to rebuild derived buffers
        # for the new batch size.
        if hasattr(cuda_graph_metadata, "__post_init__"):
            cuda_graph_metadata.__post_init__()
        return cuda_graph_metadata
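
# Illustrative usage (not part of the original file): CUDA graph capture
# clones the metadata with a fixed batch size; the clone is flagged so that
# repeated calls return it unchanged instead of cloning again.
#
#   meta = SpecMetadata(max_num_requests=8, max_draft_tokens=4)
#   graph_meta = meta.create_cuda_graph_metadata(max_batch_size=8)
#   assert graph_meta.is_cuda_graph and graph_meta is not meta
#   assert graph_meta.create_cuda_graph_metadata(8) is graph_meta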