from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Optional, Tuple

import torch

from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode


@dataclass
class Eagle3Config(SpecConfig):
    spec_dec_name: str = "EAGLE3"
    eagle_weights_path: Optional[str] = None
    num_layers: int = 0
    hidden_size: int = 0

    def __post_init__(self):
        if self.eagle_weights_path is None:
            raise ValueError("Path to EAGLE3 weights must be specified.")

        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
            self.spec_dec_name)
        self.num_extra_kv_tokens = 0

    def update_from_model_config(self, model_config):
        self.num_layers = model_config.num_hidden_layers
        self.hidden_size = model_config.hidden_size

    def get_draft_model_prompt(self,
                               input_tokens: torch.Tensor) -> torch.Tensor:
        """
        Eagle3 always throws away the first token when processing draft inputs.
        """
        return input_tokens[1:]


@dataclass
class Eagle3SpecMetadata(SpecMetadata):
    hidden_states: List[torch.Tensor] = field(default_factory=list)
    num_layers: int = 0
    layers_to_capture: Tuple[int, ...] = field(init=False)
    target_model_embed_tokens: Optional[torch.nn.Module] = None
    hidden_size: int = 0

    def __post_init__(self):
        if self.num_layers == 1:
            self.layers_to_capture = (0, )
        else:
            if self.num_layers <= 5:
                raise ValueError("Not enough hidden layers for EAGLE")

            # Capture one early, one middle, and one late layer.
            self.layers_to_capture = (1, self.num_layers // 2 - 1,
                                      self.num_layers - 4)

        self.hidden_states = []
        if self.is_cuda_graph:
            # CUDA graphs need to use the same buffers between runs.
            max_seqlen = self.max_num_requests * (self.max_draft_tokens + 1)
            hidden_state_shape = (max_seqlen, self.hidden_size)
            for _ in self.layers_to_capture:
                self.hidden_states.append(
                    torch.empty(hidden_state_shape, device='cuda'))

    def prepare(self):
        if not self.is_cuda_graph:
            self.hidden_states = []

    def maybe_capture_hidden_states(self, layer_id: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        if not self.is_cuda_graph:
            if layer_id in self.layers_to_capture:
                self.hidden_states.append(hidden_states + residual)
        else:
            # In CUDA graph mode, copy into the preallocated buffers instead
            # of appending so tensor addresses stay stable across replays.
            assert len(self.hidden_states) == len(self.layers_to_capture)
            for i, captured_layer_id in enumerate(self.layers_to_capture):
                if captured_layer_id == layer_id:
                    self.hidden_states[i].copy_(hidden_states + residual)
                    break

    def get_hidden_states(
            self,
            scheduled_requests,
            num_rejected_tokens: Optional[Dict] = None) -> torch.Tensor:
        req_id_to_gather_ids = {}
        seq_start = 0
        for req_id, seqlen in zip(self.request_ids, self.seq_lens):
            if num_rejected_tokens is not None:
                if req_id in num_rejected_tokens:
                    req_id_to_gather_ids[req_id] = list(
                        range(
                            seq_start,
                            seq_start + seqlen - num_rejected_tokens[req_id]))
            else:
                req_id_to_gather_ids[req_id] = [seq_start + seqlen - 1]

            seq_start += seqlen

        hidden_states_gather_ids = []
        for req in chain(scheduled_requests.context_requests,
                         scheduled_requests.generation_requests):
            hidden_states_gather_ids.extend(
                req_id_to_gather_ids[req.py_request_id])

        if len(self.hidden_states) == 1:
            return self.hidden_states[0][hidden_states_gather_ids]
        else:
            # Note that we must call cat() here. We can't have this control
            # flow inside the model - that would break CUDA graphs.
            return torch.cat(
                [h[hidden_states_gather_ids] for h in self.hidden_states],
                dim=-1)
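
# A minimal sketch of the capture-layer selection above, assuming a
# hypothetical 32-layer target model (the same formula applies to any model
# with more than 5 layers). EAGLE3 drafts from one early, one middle, and one
# late hidden state, which get_hidden_states() concatenates along the hidden
# dimension:
#
#   num_layers = 32
#   layers_to_capture = (1, num_layers // 2 - 1, num_layers - 4)  # (1, 15, 28)
#
# so with three captured layers the draft model receives a tensor of width
# 3 * hidden_size per gathered token.
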
class Eagle3Sampler(TorchSampler):

    def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
        logits = model_outputs["logits"]
        new_tokens_device = torch.argmax(logits, dim=-1)
        # "d2t" remaps draft-vocabulary token ids to target-vocabulary ids.
        if "d2t" in model_outputs:
            d2t = model_outputs["d2t"]
            new_tokens_device = d2t[new_tokens_device] + new_tokens_device
        device = SampleStateTensors(new_tokens=new_tokens_device)
        host = SampleStateTensors(
            new_tokens=new_tokens_device.to('cpu', non_blocking=True))
        sampler_event = torch.cuda.Event()
        sampler_event.record()
        return SampleState(scheduled_requests=scheduled_requests,
                           logits=logits,
                           device=device,
                           host=host,
                           sampler_event=sampler_event)
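
# A minimal sketch of the "d2t" remapping in _batch_sample, using made-up
# numbers (the real table, if present, arrives in model_outputs). The draft
# model may use a reduced vocabulary; "d2t" holds per-draft-token offsets into
# the target vocabulary, so target_id = d2t[draft_id] + draft_id:
#
#   import torch
#
#   d2t = torch.tensor([0, 5, 9])              # hypothetical offset table
#   draft_ids = torch.tensor([2, 0, 1])        # argmax over draft logits
#   target_ids = d2t[draft_ids] + draft_ids    # tensor([11, 0, 6])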