TensorRT-LLM/tensorrt_llm/_torch/speculative/eagle3.py

from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Optional, Tuple

import torch

from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode


@dataclass
class Eagle3Config(SpecConfig):
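    """Configuration for EAGLE3 speculative decoding. ``eagle_weights_path``
    must point to the EAGLE3 weights; ``num_layers`` and ``hidden_size`` are
    filled in from the target model config via ``update_from_model_config``."""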
spec_dec_name: str = "EAGLE3"
eagle_weights_path: Optional[str] = None
num_layers: int = 0
    hidden_size: int = 0

    def __post_init__(self):
if self.eagle_weights_path is None:
raise ValueError("Path to EAGLE3 weights must be specified.")
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
        self.num_extra_kv_tokens = 0

    def update_from_model_config(self, model_config):
self.num_layers = model_config.num_hidden_layers
        self.hidden_size = model_config.hidden_size

    def get_draft_model_prompt(self,
                               input_tokens: torch.Tensor) -> torch.Tensor:
"""
Eagle3 always throws away the first token when processing draft inputs
"""
        return input_tokens[1:]


@dataclass
class Eagle3SpecMetadata(SpecMetadata):
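    """Metadata for EAGLE3: captures hidden states from selected target-model
    layers so they can later be fused into the draft model input."""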
hidden_states: List[torch.Tensor] = field(default_factory=list)
num_layers: int = 0
layers_to_capture: Tuple[int, ...] = field(init=False)
target_model_embed_tokens: Optional[torch.nn.Module] = None
    hidden_size: int = 0

    def __post_init__(self):
if self.num_layers == 1:
self.layers_to_capture = (0, )
else:
if self.num_layers <= 5:
raise ValueError("Not enough hidden layers for EAGLE")
self.layers_to_capture = (1, self.num_layers // 2 - 1,
self.num_layers - 4)
self.hidden_states = []
if self.is_cuda_graph:
# CUDA graphs need to use the same buffers between runs.
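            # Preallocate for the worst case: each request can contribute up to
            # max_draft_tokens + 1 tokens per step.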
max_seqlen = self.max_num_requests * (self.max_draft_tokens + 1)
hidden_state_shape = (max_seqlen, self.hidden_size)
for layer in self.layers_to_capture:
self.hidden_states.append(
                    torch.empty(hidden_state_shape, device='cuda'))

    def prepare(self):
if not self.is_cuda_graph:
            self.hidden_states = []

    def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
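        # Eager mode appends freshly computed tensors; CUDA graph mode copies
        # into the buffers preallocated in __post_init__ so the addresses
        # recorded in the graph remain valid.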
if not self.is_cuda_graph:
if layer_id in self.layers_to_capture:
self.hidden_states.append(hidden_states + residual)
else:
assert len(self.hidden_states) == len(self.layers_to_capture)
for i, captured_layer_id in enumerate(self.layers_to_capture):
if captured_layer_id == layer_id:
self.hidden_states[i].copy_(hidden_states + residual)
                    break

    def get_hidden_states(
self,
scheduled_requests,
num_rejected_tokens: Optional[Dict] = None) -> torch.Tensor:
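        # For each request, pick the positions whose hidden states feed the
        # next drafting step: every accepted position when rejection counts are
        # provided, otherwise just the final position of the sequence.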
req_id_to_gather_ids = {}
seq_start = 0
for req_id, seqlen in zip(self.request_ids, self.seq_lens):
if num_rejected_tokens is not None:
if req_id in num_rejected_tokens:
req_id_to_gather_ids[req_id] = list(
range(seq_start,
seq_start + seqlen - num_rejected_tokens[req_id]))
else:
req_id_to_gather_ids[req_id] = [seq_start + seqlen - 1]
            seq_start += seqlen

        hidden_states_gather_ids = []
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
hidden_states_gather_ids.extend(
req_id_to_gather_ids[req.py_request_id])
if len(self.hidden_states) == 1:
return self.hidden_states[0][hidden_states_gather_ids]
else:
# Note that we must call cat() here. We can't have this control
# flow inside the model - that would break CUDA graphs.
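            # Concatenating along the hidden dimension yields the fused
            # multi-layer feature tensor consumed by the EAGLE3 draft model.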
return torch.cat(
[h[hidden_states_gather_ids] for h in self.hidden_states],
                dim=-1)


class Eagle3Sampler(TorchSampler):
    """Greedy sampler for the EAGLE3 draft model. When the model exposes a
    ``d2t`` table, draft-vocabulary token ids are mapped back to the target
    vocabulary before being returned as new tokens.
    """

    def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
device = SampleStateTensors(new_tokens=new_tokens_device)
host = SampleStateTensors(
new_tokens=new_tokens_device.to('cpu', non_blocking=True))
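        # Record an event after the async D2H copy so consumers can wait for
        # the host tensor to become valid.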
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(scheduled_requests=scheduled_requests,
logits=logits,
device=device,
host=host,
sampler_event=sampler_event)