TensorRT-LLM/tensorrt_llm/_torch/speculative/eagle3.py

from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
from .mtp import MTPSampler


@dataclass
class Eagle3Config(SpecConfig):
spec_dec_name: str = "EAGLE3"
num_layers: int = 0
hidden_size: int = 0
eagle3_one_model: bool = True
def __post_init__(self):
if self.draft_model_path is None:
raise ValueError("Path to EAGLE3 weights must be specified.")
if self.eagle3_one_model:
self.spec_dec_mode = SpeculativeDecodingMode.EAGLE3_ONE_MODEL
self.num_extra_kv_tokens = self.max_draft_tokens - 1
else:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.num_extra_kv_tokens = 0
logger.info(f"EAGLE3 Config: {self}")
def update_from_model_config(self, model_config):
self.num_layers = model_config.num_hidden_layers
self.hidden_size = model_config.hidden_size
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Eagle3 always throws away the first token when processing draft inputs
"""
return input_tokens[1:]
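
# Illustrative usage sketch (not part of this module): how an Eagle3Config
# might be built for a hypothetical target model. `max_draft_tokens` and
# `draft_model_path` are fields inherited from SpecConfig; the path and values
# below are placeholders, not defaults.
#
#   spec_config = Eagle3Config(
#       max_draft_tokens=3,
#       draft_model_path="/checkpoints/eagle3-draft",  # hypothetical path
#       eagle3_one_model=True,  # fuse the draft and target models into one engine
#   )
#   spec_config.update_from_model_config(target_model_config)
#   # -> num_layers / hidden_size now mirror the target model; with
#   #    eagle3_one_model=True, num_extra_kv_tokens == max_draft_tokens - 1.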


@dataclass
class Eagle3SpecMetadata(SpecMetadata):
hidden_states: List[torch.Tensor] = field(default_factory=list)
num_layers: int = 0
layers_to_capture: Tuple[int, ...] = field(init=False)
target_model_embed_tokens: Optional[torch.nn.Module] = None
hidden_size: int = 0
def __post_init__(self):
if self.num_layers == 1:
self.layers_to_capture = (0, )
else:
if self.num_layers <= 5:
raise ValueError("Not enough hidden layers for EAGLE")
self.layers_to_capture = (1, self.num_layers // 2 - 1,
self.num_layers - 4)
self.hidden_states = []
if self.is_cuda_graph:
# CUDA graphs need to use the same buffers between runs.
max_seqlen = self.max_num_requests * (self.max_draft_tokens + 1)
hidden_state_shape = (max_seqlen, self.hidden_size)
for layer in self.layers_to_capture:
self.hidden_states.append(
torch.empty(hidden_state_shape, device='cuda'))
def prepare(self):
if not self.is_cuda_graph:
self.hidden_states = []
def is_layer_capture(self, layer_id: int):
return layer_id in self.layers_to_capture
def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
if not self.is_cuda_graph:
if layer_id in self.layers_to_capture:
self.hidden_states.append(hidden_states + residual)
else:
assert len(self.hidden_states) == len(self.layers_to_capture)
for i, captured_layer_id in enumerate(self.layers_to_capture):
if captured_layer_id == layer_id:
self.hidden_states[i].copy_(hidden_states + residual)
break
def get_hidden_states(
self,
scheduled_requests,
num_rejected_tokens: Optional[Dict] = None) -> torch.Tensor:
req_id_to_gather_ids = {}
seq_start = 0
for req_id, seqlen in zip(self.request_ids, self.seq_lens):
if num_rejected_tokens is not None:
if req_id in num_rejected_tokens:
req_id_to_gather_ids[req_id] = list(
range(seq_start,
seq_start + seqlen - num_rejected_tokens[req_id]))
else:
req_id_to_gather_ids[req_id] = [seq_start + seqlen - 1]
seq_start += seqlen
hidden_states_gather_ids = []
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
hidden_states_gather_ids.extend(
req_id_to_gather_ids[req.py_request_id])
if len(self.hidden_states) == 1:
return self.hidden_states[0][hidden_states_gather_ids]
else:
# Note that we must call cat() here. We can't have this control
# flow inside the model - that would break CUDA graphs.
return torch.cat(
[h[hidden_states_gather_ids] for h in self.hidden_states],
dim=-1)
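
# Worked example (a sketch, not executed here): for a 32-layer target model,
# layers_to_capture == (1, 15, 28), so get_hidden_states() returns a tensor of
# width 3 * hidden_size built by concatenating the three captured
# (hidden_states + residual) tensors along the last dimension. The draft model
# is expected to project this concatenation back down to hidden_size (in the
# one-model path below this is draft_model.apply_eagle3_fc).
#
#   num_layers = 32
#   layers_to_capture = (1, 32 // 2 - 1, 32 - 4)   # == (1, 15, 28)
#   # captured: 3 tensors of shape [num_tokens, hidden_size]
#   # returned: [num_gathered_tokens, 3 * hidden_size]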


class Eagle3Sampler(TorchSampler):
def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
device = SampleStateTensors(new_tokens=new_tokens_device)
host = SampleStateTensors(
new_tokens=new_tokens_device.to('cpu', non_blocking=True))
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(scheduled_requests=scheduled_requests,
logits=logits,
device=device,
host=host,
sampler_event=sampler_event)
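
# Note on the d2t remap above (a sketch with made-up numbers): EAGLE3 draft
# models may use a smaller vocabulary than the target model, and "d2t" stores,
# for each draft-vocab id, the offset to the corresponding target-vocab id,
# i.e. target_id = draft_id + d2t[draft_id].
#
#   d2t = torch.tensor([0, 0, 5, 5])          # hypothetical offsets
#   draft_ids = torch.tensor([2, 0])
#   target_ids = d2t[draft_ids] + draft_ids   # -> tensor([7, 0])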


@dataclass
class Eagle3OneModelSpecMetadata(SpecMetadata):
# The hidden states
hidden_states: Optional[torch.Tensor] = None
# The number of layers
num_layers: int = 0
# The layers to be captured
layers_to_capture: Tuple[int, ...] = field(init=False)
# The hidden size of the hidden states
hidden_size: int = 0
# The max number of tokens
max_num_tokens: int = 0
# The dtype of the hidden states
dtype: torch.dtype = torch.bfloat16
    # The indices of the requests in the batch
batch_indices_cuda: Optional[torch.Tensor] = None
def __post_init__(self):
if self.num_layers == 1:
self.layers_to_capture = (1, )
else:
if self.num_layers <= 5:
raise ValueError("Not enough hidden layers for EAGLE")
self.layers_to_capture = (1, self.num_layers // 2 - 1,
self.num_layers - 4)
self.hidden_states = torch.empty(
(self.max_num_tokens,
self.hidden_size * len(self.layers_to_capture)),
dtype=self.dtype,
device='cuda')
self.batch_indices_cuda = torch.empty(
[self.max_num_requests],
dtype=torch.int,
device='cuda',
)
def is_layer_capture(self, layer_id: int):
return layer_id in self.layers_to_capture
def prepare(self):
assert self.request_ids is not None
        # update batch indices
num_seqs = len(self.request_ids)
batch_indices = torch.arange(num_seqs,
dtype=torch.int,
device='cpu',
pin_memory=True)
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
non_blocking=True)
self.num_tokens -= (self.num_generations) * self.max_draft_tokens
def maybe_capture_hidden_states(
self,
layer_id: int,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None) -> None:
for i, captured_layer_id in enumerate(self.layers_to_capture):
if captured_layer_id == layer_id:
num_tokens = hidden_states.shape[0]
to_save = hidden_states + residual if residual is not None else hidden_states
self.hidden_states[:num_tokens, i * self.hidden_size:(i + 1) *
self.hidden_size].copy_(to_save,
non_blocking=True)
break


class Eagle3Decoder(TorchSampler):
def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
new_tensors_device = {"new_tokens_device": new_tokens_device}
new_tensors_host = {"new_tokens_host": new_tokens_host}
decoder_event = torch.cuda.Event()
decoder_event.record()
return SampleState(scheduled_requests=scheduled_requests,
logits=logits,
new_tensors_device=new_tensors_device,
new_tensors_host=new_tensors_host,
decoder_event=decoder_event)


class Eagle3OneModelDecoder(MTPSampler):
def __init__(self, max_seq_len: int, config: Eagle3Config):
super().__init__(max_seq_len, None)
self.draft_len = config.max_draft_tokens


class Eagle3OneModelWorker(nn.Module):
def __init__(self, spec_config: Eagle3Config, mapping: Mapping):
super().__init__()
self.spec_config = spec_config
self.max_draft_tokens = self.spec_config.max_draft_tokens
self.mapping = mapping
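
    # High-level flow of forward() below (descriptive comment, inferred from
    # the code): given the target model's logits and captured hidden states,
    # (1) greedily sample target tokens and count how many previously proposed
    #     draft tokens were accepted,
    # (2) run the draft model max_draft_tokens times autoregressively,
    #     gathering one new draft token per sequence per step,
    # (3) restore any CUDA-graph-owned attention metadata, and
    # (4) return the target logits, the accepted tokens and their counts, the
    #     next draft tokens, and the tokens the overlap scheduler should feed
    #     into the next iteration.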
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, input_ids, position_ids, hidden_states, logits,
attn_metadata, spec_metadata, draft_model):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
raw_logits = logits
# Sample and accept tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
logits, attn_metadata, spec_metadata)
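        # Shapes (from sample_and_accept_draft_tokens below):
        #   accepted_tokens:     [batch_size, max_draft_tokens + 1], int
        #   num_accepted_tokens: [batch_size], int, always >= 1 (the bonus token)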
# Save the old attn_metadata and spec_metadata
if attn_metadata.is_cuda_graph:
seq_len = attn_metadata._seq_lens[:batch_size].clone()
seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
# Prepare inputs for the 1st draft model forward
position_ids = position_ids.squeeze(0)
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
inputs = self.prepare_1st_drafter_inputs(
input_ids=input_ids,
position_ids=position_ids,
last_tokens_idx=last_tokens_idx,
hidden_states=hidden_states,
accepted_tokens=accepted_tokens,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_model=draft_model)
# Predict draft tokens
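        # Each iteration below runs the draft model once and greedily picks one
        # more draft token per sequence. On the first iteration the gather ids
        # point at the last prompt token of each context request and the last
        # accepted token of each generation request; from the second iteration
        # on, every sequence contributes exactly one token, so the batch
        # indices double as gather ids and all seq_lens collapse to 1.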
next_draft_tokens = []
for i in range(self.max_draft_tokens):
hidden_states, hidden_states_to_save = draft_model.model(**inputs)
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_tokens + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
                # All seq_lens are 1, so use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
                # cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_tokens -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
if spec_metadata.all_rank_num_tokens is not None:
spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn_metadata to support cuda graph
if attn_metadata.is_cuda_graph:
attn_metadata._seq_lens[:batch_size].copy_(seq_len)
attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda)
attn_metadata.on_update()
# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
return {
'logits': raw_logits,
'new_tokens': accepted_tokens,
'new_tokens_lens': num_accepted_tokens,
'next_draft_tokens': next_draft_tokens,
'next_new_tokens': next_new_tokens,
}
def sample_and_accept_draft_tokens(
self,
logits: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: Eagle3OneModelSpecMetadata,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
if logits.dim() == 1:
logits = logits.unsqueeze(0)
# The return buffer
accepted_tokens = torch.empty((batch_size, (self.max_draft_tokens + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
# Do greedy sampling for the input logits
target_tokens = torch.argmax(logits, dim=-1)
# context
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
# generation
gen_target_tokens = target_tokens[num_contexts:].reshape(
num_gens, self.max_draft_tokens + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, self.max_draft_tokens)
num_accepted_tokens[num_contexts:] += torch.cumprod((
draft_tokens == gen_target_tokens[:, :self.max_draft_tokens]).int(),
dim=-1).sum(1)
return accepted_tokens, num_accepted_tokens
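
    # Worked example for the acceptance rule above (made-up numbers,
    # max_draft_tokens == 3): suppose a generation request previously proposed
    # draft tokens [5, 7, 9] and the target model's greedy tokens at those
    # positions are [5, 7, 3, 8] (max_draft_tokens + 1 of them).
    #   matches          = [1, 1, 0]     # 5==5, 7==7, 9!=3
    #   cumprod(matches) = [1, 1, 0]     # stop at the first mismatch
    #   num_accepted     = 1 + sum(...)  # == 3: two drafts plus one new token
    # The full row [5, 7, 3, 8] is stored in accepted_tokens; num_accepted_tokens
    # marks how much of it is valid.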
def draft_decoder(
self,
logits: torch.Tensor,
draft_model: nn.Module,
):
        '''
        Greedily sample draft tokens from the draft model logits.

        Args:
            logits: torch.Tensor
                [num_tokens, vocab_size]
                Logits produced by the draft model.
            draft_model: nn.Module
                The draft model.
        Returns:
            draft_tokens: torch.Tensor
                [num_tokens]
                Draft token ids, one per row of logits, remapped to the
                target model vocabulary via d2t when available.
        '''
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
        # Apply d2t (offsets between the draft-model vocabulary and the target-model vocabulary).
if hasattr(draft_model.model,
"d2t") and draft_model.model.d2t is not None:
draft_tokens = draft_model.model.d2t[draft_tokens] + draft_tokens
return draft_tokens
def prepare_1st_drafter_inputs(
self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
last_tokens_idx: torch.LongTensor,
hidden_states: torch.Tensor,
accepted_tokens: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: Eagle3OneModelSpecMetadata,
draft_model: nn.Module,
):
num_contexts = attn_metadata.num_contexts
num_tokens = input_ids.shape[0]
# prepare hidden states
hidden_size_up = spec_metadata.hidden_size * len(
spec_metadata.layers_to_capture)
hidden_states = spec_metadata.hidden_states[:num_tokens, :
hidden_size_up]
hidden_states = draft_model.apply_eagle3_fc(hidden_states)
# context
input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_ctx_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
input_ids_ctx[
last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
# generation
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
# get draft inputs
input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
return {
"input_ids": input_ids,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
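
# Worked example for prepare_1st_drafter_inputs (made-up tokens): EAGLE3 drops
# the first prompt token (see Eagle3Config.get_draft_model_prompt), so for a
# single context request with prompt [t0, t1, t2, t3] and newly sampled target
# token a0, the draft model's first-pass input ids become:
#
#   input_ids_ctx[:-1] = [t1, t2, t3]   # shift left by one
#   input_ids_ctx[-1]  = a0             # accepted_tokens[req, 0]
#   # -> [t1, t2, t3, a0], aligned with the captured hidden states of
#   #    positions [t0, t1, t2, t3].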