from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
from .mtp import MTPSampler


@dataclass
class Eagle3Config(SpecConfig):
    spec_dec_name: str = "EAGLE3"
    num_layers: int = 0
    hidden_size: int = 0
    eagle3_one_model: bool = True

    def __post_init__(self):
        if self.draft_model_path is None:
            raise ValueError("Path to EAGLE3 weights must be specified.")
        if self.eagle3_one_model:
            self.spec_dec_mode = SpeculativeDecodingMode.EAGLE3_ONE_MODEL
            self.num_extra_kv_tokens = self.max_draft_tokens - 1
        else:
            self.spec_dec_mode = SpeculativeDecodingMode.from_string(
                self.spec_dec_name)
            self.num_extra_kv_tokens = 0
        logger.info(f"EAGLE3 Config: {self}")

    def update_from_model_config(self, model_config):
        self.num_layers = model_config.num_hidden_layers
        self.hidden_size = model_config.hidden_size

    def get_draft_model_prompt(self,
                               input_tokens: torch.Tensor) -> torch.Tensor:
        """
        Eagle3 always throws away the first token when processing draft inputs.
        """
        return input_tokens[1:]


@dataclass
class Eagle3SpecMetadata(SpecMetadata):
    hidden_states: List[torch.Tensor] = field(default_factory=list)
    num_layers: int = 0
    layers_to_capture: Tuple[int, ...] = field(init=False)
    target_model_embed_tokens: Optional[torch.nn.Module] = None
    hidden_size: int = 0

    def __post_init__(self):
        if self.num_layers == 1:
            self.layers_to_capture = (0, )
        else:
            if self.num_layers <= 5:
                raise ValueError("Not enough hidden layers for EAGLE")
            self.layers_to_capture = (1, self.num_layers // 2 - 1,
                                      self.num_layers - 4)
        self.hidden_states = []
        if self.is_cuda_graph:
            # CUDA graphs need to use the same buffers between runs.
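            # Size the capture buffers for the worst case: each request can
            # contribute one target token plus max_draft_tokens draft tokens.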
            max_seqlen = self.max_num_requests * (self.max_draft_tokens + 1)
            hidden_state_shape = (max_seqlen, self.hidden_size)
            for layer in self.layers_to_capture:
                self.hidden_states.append(
                    torch.empty(hidden_state_shape, device='cuda'))

    def prepare(self):
        if not self.is_cuda_graph:
            self.hidden_states = []

    def is_layer_capture(self, layer_id: int):
        return layer_id in self.layers_to_capture

    def maybe_capture_hidden_states(self, layer_id: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        if not self.is_cuda_graph:
            if layer_id in self.layers_to_capture:
                self.hidden_states.append(hidden_states + residual)
        else:
            assert len(self.hidden_states) == len(self.layers_to_capture)
            for i, captured_layer_id in enumerate(self.layers_to_capture):
                if captured_layer_id == layer_id:
                    self.hidden_states[i].copy_(hidden_states + residual)
                    break

    def get_hidden_states(
            self,
            scheduled_requests,
            num_rejected_tokens: Optional[Dict] = None) -> torch.Tensor:
        req_id_to_gather_ids = {}
        seq_start = 0
        for req_id, seqlen in zip(self.request_ids, self.seq_lens):
            if num_rejected_tokens is not None:
                if req_id in num_rejected_tokens:
                    req_id_to_gather_ids[req_id] = list(
                        range(
                            seq_start, seq_start + seqlen -
                            num_rejected_tokens[req_id]))
            else:
                req_id_to_gather_ids[req_id] = [seq_start + seqlen - 1]
            seq_start += seqlen

        hidden_states_gather_ids = []
        for req in chain(scheduled_requests.context_requests,
                         scheduled_requests.generation_requests):
            hidden_states_gather_ids.extend(
                req_id_to_gather_ids[req.py_request_id])

        if len(self.hidden_states) == 1:
            return self.hidden_states[0][hidden_states_gather_ids]
        else:
            # Note that we must call cat() here. We can't have this control
            # flow inside the model - that would break CUDA graphs.
            return torch.cat(
                [h[hidden_states_gather_ids] for h in self.hidden_states],
                dim=-1)


class Eagle3Sampler(TorchSampler):

    def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
        logits = model_outputs["logits"]
        new_tokens_device = torch.argmax(logits, dim=-1)
        if "d2t" in model_outputs:
            d2t = model_outputs["d2t"]
            new_tokens_device = d2t[new_tokens_device] + new_tokens_device
        device = SampleStateTensors(new_tokens=new_tokens_device)
        host = SampleStateTensors(
            new_tokens=new_tokens_device.to('cpu', non_blocking=True))
        sampler_event = torch.cuda.Event()
        sampler_event.record()
        return SampleState(scheduled_requests=scheduled_requests,
                           logits=logits,
                           device=device,
                           host=host,
                           sampler_event=sampler_event)


@dataclass
class Eagle3OneModelSpecMetadata(SpecMetadata):
    # The hidden states
    hidden_states: Optional[torch.Tensor] = None
    # The number of layers
    num_layers: int = 0
    # The layers to be captured
    layers_to_capture: Tuple[int, ...] = field(init=False)
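    # (filled in __post_init__ with a low, a middle, and a late decoder
    # layer; their captured hidden states are concatenated to build the
    # draft model input)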
    # The hidden size of the hidden states
    hidden_size: int = 0
    # The max number of tokens
    max_num_tokens: int = 0
    # The dtype of the hidden states
    dtype: torch.dtype = torch.bfloat16
    # The indices of the batched inputs
    batch_indices_cuda: Optional[torch.Tensor] = None

    def __post_init__(self):
        if self.num_layers == 1:
            self.layers_to_capture = (1, )
        else:
            if self.num_layers <= 5:
                raise ValueError("Not enough hidden layers for EAGLE")
            self.layers_to_capture = (1, self.num_layers // 2 - 1,
                                      self.num_layers - 4)
        self.hidden_states = torch.empty(
            (self.max_num_tokens,
             self.hidden_size * len(self.layers_to_capture)),
            dtype=self.dtype,
            device='cuda')
        self.batch_indices_cuda = torch.empty(
            [self.max_num_requests],
            dtype=torch.int,
            device='cuda',
        )

    def is_layer_capture(self, layer_id: int):
        return layer_id in self.layers_to_capture

    def prepare(self):
        assert self.request_ids is not None
        # update batch indices
        num_seqs = len(self.request_ids)
        batch_indices = torch.arange(num_seqs,
                                     dtype=torch.int,
                                     device='cpu',
                                     pin_memory=True)
        self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
                                                 non_blocking=True)
        self.num_tokens -= self.num_generations * self.max_draft_tokens

    def maybe_capture_hidden_states(
            self,
            layer_id: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        for i, captured_layer_id in enumerate(self.layers_to_capture):
            if captured_layer_id == layer_id:
                num_tokens = hidden_states.shape[0]
                to_save = (hidden_states + residual
                           if residual is not None else hidden_states)
                self.hidden_states[:num_tokens, i * self.hidden_size:(i + 1) *
                                   self.hidden_size].copy_(to_save,
                                                           non_blocking=True)
                break


class Eagle3Decoder(TorchSampler):

    def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
        logits = model_outputs["logits"]
        new_tokens_device = torch.argmax(logits, dim=-1)
        if "d2t" in model_outputs:
            d2t = model_outputs["d2t"]
            new_tokens_device = d2t[new_tokens_device] + new_tokens_device
        new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
        new_tensors_device = {"new_tokens_device": new_tokens_device}
        new_tensors_host = {"new_tokens_host": new_tokens_host}
        decoder_event = torch.cuda.Event()
        decoder_event.record()
        return SampleState(scheduled_requests=scheduled_requests,
                           logits=logits,
                           new_tensors_device=new_tensors_device,
                           new_tensors_host=new_tensors_host,
                           decoder_event=decoder_event)


class Eagle3OneModelDecoder(MTPSampler):

    def __init__(self, max_seq_len: int, config: Eagle3Config):
        super().__init__(max_seq_len, None)
        self.draft_len = config.max_draft_tokens


class Eagle3OneModelWorker(nn.Module):

    def __init__(self, spec_config: Eagle3Config, mapping: Mapping):
        super().__init__()
        self.spec_config = spec_config
        self.max_draft_tokens = self.spec_config.max_draft_tokens
        self.mapping = mapping

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, input_ids, position_ids, hidden_states, logits,
                attn_metadata, spec_metadata, draft_model):
        batch_size = attn_metadata.num_seqs
        num_contexts = attn_metadata.num_contexts
        num_gens = batch_size - num_contexts
        raw_logits = logits

        # Sample and accept tokens
        accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
            logits, attn_metadata, spec_metadata)

        # Save the old seq_lens so they can be restored after drafting
        # (CUDA graphs must see the same metadata between runs).
        if attn_metadata.is_cuda_graph:
            seq_len = attn_metadata._seq_lens[:batch_size].clone()
            seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()

        # Prepare inputs for the 1st draft model forward
        position_ids = position_ids.squeeze(0)
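        # last_tokens_idx[i] is the index of request i's final token in the
        # flattened token stream (inclusive prefix sums of seq_lens, minus 1).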
        last_tokens_idx = torch.cumsum(attn_metadata.seq_lens_cuda,
                                       dim=0,
                                       dtype=torch.long) - 1
        inputs = self.prepare_1st_drafter_inputs(
            input_ids=input_ids,
            position_ids=position_ids,
            last_tokens_idx=last_tokens_idx,
            hidden_states=hidden_states,
            accepted_tokens=accepted_tokens,
            attn_metadata=attn_metadata,
            spec_metadata=spec_metadata,
            draft_model=draft_model)

        # Predict draft tokens
        next_draft_tokens = []
        for i in range(self.max_draft_tokens):
            hidden_states, hidden_states_to_save = draft_model.model(**inputs)
            if i == 0:
                start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
                                 (self.max_draft_tokens + 1)).long()
                gather_ids_gen = (start_ids_gen +
                                  num_accepted_tokens[num_contexts:] - 1 +
                                  attn_metadata.num_ctx_tokens)
                gather_ids = torch.concat(
                    [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
            else:
                # All seq_lens are 1, so use batch_indices_cuda as gather_ids.
                gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
            logits = draft_model.logits_processor(hidden_states[gather_ids],
                                                  draft_model.lm_head,
                                                  attn_metadata, True)
            new_draft_token = self.draft_decoder(logits, draft_model)
            next_draft_tokens.append(new_draft_token)
            # update inputs
            hidden_states = hidden_states_to_save[gather_ids]
            position_ids = inputs["position_ids"][gather_ids] + 1
            # update attn_metadata
            if i == 0:
                attn_metadata._seq_lens[:batch_size].fill_(1)
                attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
                attn_metadata.on_update()
                # cannot run generation if there is no kv cache
                if inputs["attn_metadata"].kv_cache_manager is not None:
                    attn_metadata.host_request_types[:attn_metadata.
                                                     num_contexts].fill_(1)
                    attn_metadata.num_contexts = 0
                # update kv_lens_cuda
                if hasattr(attn_metadata, 'kv_lens_cuda'):
                    attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
                        self.max_draft_tokens -
                        num_accepted_tokens[num_contexts:])
                    attn_metadata.kv_lens_cuda[:num_contexts] += 1
            elif hasattr(attn_metadata, 'kv_lens_cuda'):
                attn_metadata.kv_lens_cuda[:batch_size] += 1
            # support attention dp
            if spec_metadata.all_rank_num_tokens is not None:
                spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
            inputs = {
                "input_ids": new_draft_token,
                "position_ids": position_ids,
                "hidden_states": hidden_states,
                "attn_metadata": attn_metadata,
                "spec_metadata": spec_metadata,
            }
        next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

        # restore attn_metadata to support cuda graph
        if attn_metadata.is_cuda_graph:
            attn_metadata._seq_lens[:batch_size].copy_(seq_len)
            attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda)
            attn_metadata.on_update()

        # prepare next new tokens to support overlap scheduler
        next_new_tokens = accepted_tokens[
            spec_metadata.batch_indices_cuda[:batch_size],
            num_accepted_tokens - 1].unsqueeze(1)
        next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
                                       dim=1)

        return {
            'logits': raw_logits,
            'new_tokens': accepted_tokens,
            'new_tokens_lens': num_accepted_tokens,
            'next_draft_tokens': next_draft_tokens,
            'next_new_tokens': next_new_tokens,
        }

    def sample_and_accept_draft_tokens(
        self,
        logits: torch.Tensor,
        attn_metadata: AttentionMetadata,
        spec_metadata: Eagle3OneModelSpecMetadata,
    ):
        batch_size = attn_metadata.num_seqs
        num_contexts = attn_metadata.num_contexts
        num_gens = batch_size - num_contexts

        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        # The return buffers
        accepted_tokens = torch.empty((batch_size, self.max_draft_tokens + 1),
                                      dtype=torch.int,
                                      device=logits.device)
        num_accepted_tokens = torch.ones(batch_size,
                                         dtype=torch.int,
                                         device=logits.device)

        # Do greedy sampling for the input logits
        target_tokens = torch.argmax(logits, dim=-1)
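        # Acceptance rule: a draft token is kept while it matches the greedy
        # target token at the same position. The cumprod() below zeroes out
        # everything after the first mismatch, so its sum is the length of
        # the accepted draft prefix.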
        # context
        accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]

        # generation
        gen_target_tokens = target_tokens[num_contexts:].reshape(
            num_gens, self.max_draft_tokens + 1)
        accepted_tokens[num_contexts:, :] = gen_target_tokens
        draft_tokens = spec_metadata.draft_tokens.reshape(
            num_gens, self.max_draft_tokens)
        num_accepted_tokens[num_contexts:] += torch.cumprod(
            (draft_tokens == gen_target_tokens[:, :self.max_draft_tokens]
             ).int(),
            dim=-1).sum(1)

        return accepted_tokens, num_accepted_tokens

    def draft_decoder(
        self,
        logits: torch.Tensor,
        draft_model: nn.Module,
    ):
        '''
        Sample draft tokens (greedy argmax).

        Args:
            logits: torch.Tensor
                [batch_size, vocab_size]
                Logits produced by the draft model.
            draft_model: nn.Module
                The draft model.

        Returns:
            draft_tokens: torch.Tensor
                [batch_size]
                Draft token ids for the current draft step.
        '''
        draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)

        # Apply d2t (offsets between the draft model vocabulary and the
        # target model vocabulary).
        if hasattr(draft_model.model,
                   "d2t") and draft_model.model.d2t is not None:
            draft_tokens = draft_model.model.d2t[draft_tokens] + draft_tokens

        return draft_tokens

    def prepare_1st_drafter_inputs(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        last_tokens_idx: torch.LongTensor,
        hidden_states: torch.Tensor,
        accepted_tokens: torch.Tensor,
        attn_metadata: AttentionMetadata,
        spec_metadata: Eagle3OneModelSpecMetadata,
        draft_model: nn.Module,
    ):
        num_contexts = attn_metadata.num_contexts
        num_tokens = input_ids.shape[0]

        # prepare hidden states
        hidden_size_up = spec_metadata.hidden_size * len(
            spec_metadata.layers_to_capture)
        hidden_states = spec_metadata.hidden_states[:num_tokens, :
                                                    hidden_size_up]
        hidden_states = draft_model.apply_eagle3_fc(hidden_states)

        # context: shift the prompt left by one and place the first accepted
        # token at each request's final position
        input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
        input_ids_ctx = torch.empty_like(input_ctx_ids,
                                         dtype=torch.int32,
                                         device="cuda")
        input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
        input_ids_ctx[
            last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]

        # generation
        input_ids_gen = accepted_tokens[num_contexts:, :].flatten()

        # get draft inputs
        input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)

        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "hidden_states": hidden_states,
            "attn_metadata": attn_metadata,
            "spec_metadata": spec_metadata,
        }
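# A worked example of the acceptance rule in sample_and_accept_draft_tokens,
# with hypothetical token ids, max_draft_tokens = 3, and a single generation
# request:
#
#   draft_tokens      = [[5, 9, 2]]      # drafted last step
#   gen_target_tokens = [[5, 9, 7, 4]]   # greedy targets for the same step
#   matches           = [1, 1, 0]        # elementwise draft == target
#   cumprod(matches)  = [1, 1, 0]        # zeros after the first mismatch
#   num_accepted      = 1 + 2 = 3        # two matched drafts plus the target
#                                        # token at the mismatch position
#
# Thus accepted_tokens[0, :3] == [5, 9, 7] is committed, and the next step
# resumes from token 7.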