import copy
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional

import torch


class SpeculativeDecodingMode(IntEnum):
    MTP = auto()
    MEDUSA = auto()
    EAGLE = auto()
    LOOKAHEAD = auto()
    NONE = auto()

    def is_mtp(self):
        return self == SpeculativeDecodingMode.MTP

    def is_medusa(self):
        return self == SpeculativeDecodingMode.MEDUSA

    def is_eagle(self):
        return self == SpeculativeDecodingMode.EAGLE

    def is_lookahead(self):
        return self == SpeculativeDecodingMode.LOOKAHEAD

    def is_none(self):
        return self == SpeculativeDecodingMode.NONE

    def without_logits(self):
        return self.is_mtp()

    def needs_kv_cache_rewind(self):
        return (self.is_mtp() or self.is_eagle() or self.is_lookahead()
                or self.is_medusa())

    def support_overlap_scheduler(self):
        return self.is_mtp()

    @staticmethod
    def from_string(name: str):
        name_map = {
            "MTP": SpeculativeDecodingMode.MTP,
            "MEDUSA": SpeculativeDecodingMode.MEDUSA,
            "EAGLE": SpeculativeDecodingMode.EAGLE,
            "LOOKAHEAD": SpeculativeDecodingMode.LOOKAHEAD,
            None: SpeculativeDecodingMode.NONE,
        }
        return name_map[name]


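# Illustrative usage (not part of the original module): resolve a mode from a
# config string and query its capability predicates. Note that `from_string`
# maps the `None` key to the NONE mode, so an unset name is handled too.
#
#   mode = SpeculativeDecodingMode.from_string("EAGLE")
#   assert mode.is_eagle()
#   assert mode.needs_kv_cache_rewind()
#   assert SpeculativeDecodingMode.from_string(None).is_none()

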
@dataclass
class SpecConfig:
    """
    Configuration for speculative decoding.
    """
    # The name of speculative decoding. Annotated so it is a real dataclass
    # field (without the annotation it would be a plain class attribute and
    # could not be set through the constructor).
    spec_dec_name: Optional[str] = None
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # The max number of draft tokens.
    max_draft_tokens: int = 1024

    def __post_init__(self) -> None:
        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
            self.spec_dec_name)


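# Illustrative usage (not part of the original module): __post_init__ resolves
# the string name into the enum mode, so the two fields stay consistent.
#
#   cfg = SpecConfig(spec_dec_name="MTP")
#   assert cfg.spec_dec_mode.is_mtp()
#   assert SpecConfig().spec_dec_mode.is_none()

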
@dataclass
class SpecMetadata:
    """
    Metadata for speculative decoding.
    """
    # The max number of requests in a single batch.
    max_num_requests: int
    # The max number of draft tokens.
    max_draft_tokens: int
    # Whether CUDA graph is enabled.
    is_cuda_graph: bool = field(default=False, repr=False)
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # Draft tokens.
    draft_tokens: Optional[torch.Tensor] = None
    # The length of the draft tokens.
    draft_lens: Optional[torch.Tensor] = None
    # The request ID of each sequence in the batch.
    # The shape is (batch_size).
    request_ids: Optional[List[int]] = None
    # The gather ids for logits.
    gather_ids: Optional[torch.Tensor] = None

    def prepare(self):
        """
        Hook to be called before the forward step of the model.
        """

    def create_cuda_graph_metadata(self, max_batch_size: int):
        """
        Creates metadata for CUDA graph execution.
        """
        if self.is_cuda_graph:
            return self

        cuda_graph_metadata = copy.copy(self)
        cuda_graph_metadata.is_cuda_graph = True
        cuda_graph_metadata.max_num_requests = max_batch_size
        # Re-run __post_init__ (expected to be defined by a subclass) so any
        # size-dependent state is rebuilt for the graph batch size.
        cuda_graph_metadata.__post_init__()
        return cuda_graph_metadata
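

# Illustrative usage (not part of the original module). copy.copy is shallow,
# so tensor fields such as draft_tokens are shared between the two objects,
# and the __post_init__ call assumes a SpecMetadata subclass defines one; the
# subclass below is hypothetical, purely for the example.
#
#   @dataclass
#   class _DemoMetadata(SpecMetadata):
#       def __post_init__(self):
#           pass  # a real subclass would rebuild size-dependent buffers here
#
#   meta = _DemoMetadata(max_num_requests=8, max_draft_tokens=4)
#   graph_meta = meta.create_cuda_graph_metadata(max_batch_size=16)
#   assert graph_meta.is_cuda_graph and graph_meta.max_num_requests == 16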