From 585fbb2734d4b50f4dd5aa03cc128d272aaa8b57 Mon Sep 17 00:00:00 2001 From: gramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:51:10 -0800 Subject: [PATCH] [#10826][feat] AutoDeploy: Eagle One-Model [2/n]: Prefill-Only Implementation (#11073) Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .../models/custom/modeling_eagle.py | 521 ++++++++++- .../_torch/auto_deploy/models/eagle.py | 35 +- .../examples/test_ad_speculative_decoding.py | 817 +++++++++++++++++- .../test_lists/test-db/l0_h100.yml | 2 + .../unit/singlegpu/models/test_eagle.py | 92 +- .../singlegpu/test_ad_speculative_decoding.py | 3 + 6 files changed, 1376 insertions(+), 94 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py index c16a6750d6..042a3cebbc 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py @@ -22,14 +22,59 @@ This file contains model definitions used for executing Eagle3 speculative decod """ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -from transformers import PreTrainedModel +from transformers import PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.utils import ModelOutput +from ...utils._config import deep_merge_dicts +from ...utils.logger import ad_logger + + +class EagleConfig(PretrainedConfig): + """Config for Eagle3 drafter models. + + Extends PretrainedConfig with Eagle-specific parameters while preserving + all base model config values. + + Args: + config: Base config for the draft model from its config.json. + model_type: The base model type (e.g., "llama") used to look up defaults. + """ + + # Map model_type -> default Eagle config values + _drafter_defaults: Dict[str, Dict[str, Any]] = { + "llama": { + "load_embedding_from_target": True, + "load_lm_head_from_target": False, + "num_capture_layers": 3, + }, + } + + def __init__(self, config: PretrainedConfig, model_type: str): + if model_type not in self._drafter_defaults: + raise ValueError( + f"Unsupported model_type '{model_type}' for EagleConfig. 
" + f"Supported types: {list(self._drafter_defaults.keys())}" + ) + + defaults = self._drafter_defaults[model_type] + config_dict = config.to_dict() + + # Log when config overrides a default + for key, value in defaults.items(): + if key in config_dict and config_dict[key] != value: + ad_logger.info( + f"EagleConfig: config has '{key}={config_dict[key]}', " + f"overriding default '{value}'" + ) + + merged = deep_merge_dicts(defaults, config_dict) + super().__init__(**merged) + class LlamaRotaryEmbedding(nn.Module): def __init__(self, config, dim, device=None, scaling_factor=1.0): @@ -247,6 +292,7 @@ class Eagle3DecoderLayer(nn.Module): def __init__(self, config, layer_idx: int = 0): super().__init__() + self.dtype = config.torch_dtype self.self_attn = Eagle3Attention(config, layer_idx=layer_idx) self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -287,7 +333,14 @@ class Eagle3Model(nn.Module): def __init__(self, config): super().__init__() - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.dtype = config.torch_dtype + + load_embedding_from_target = getattr(config, "load_embedding_from_target", False) + self.embed_tokens = ( + None + if load_embedding_from_target + else nn.Embedding(config.vocab_size, config.hidden_size) + ) if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size: # Vocab mappings for draft <-> target token conversion @@ -299,13 +352,17 @@ class Eagle3Model(nn.Module): requires_grad=False, ) - # Input feature fusion: 3 * hidden_size -> hidden_size for Eagle3. - # TODO: Can make this configurable based on number of capture layers. - self.fc = nn.Linear( - config.hidden_size * 3, - config.hidden_size, - bias=getattr(config, "bias", False), - dtype=config.torch_dtype, + # Hidden size compression for target hidden states. 
+ # Assumption: No feedforward fusion needed if we have just one capture layer (valid for MTPEagle) + self.fc = ( + nn.Linear( + config.hidden_size * config.num_capture_layers, + config.hidden_size, + bias=getattr(config, "bias", False), + dtype=self.dtype, + ) + if config.num_capture_layers > 1 + else None ) self.head_dim = getattr( @@ -328,12 +385,10 @@ class Eagle3Model(nn.Module): # Assumption: The hidden states are already fused if necessary def forward( self, - input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, position_ids: torch.LongTensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - input_embeds = self.embed_tokens(input_ids) - cos, sin = self.rotary_emb(hidden_states, position_ids) position_embeds = (cos, sin) @@ -341,25 +396,23 @@ class Eagle3Model(nn.Module): for layer in self.midlayer: hidden_states = layer( hidden_states=hidden_states, - embeds=input_embeds, + embeds=inputs_embeds, position_embeds=position_embeds, ) else: hidden_states = self.midlayer( hidden_states=hidden_states, - embeds=input_embeds, + embeds=inputs_embeds, position_embeds=position_embeds, ) return hidden_states - def apply_eagle3_fc(self, target_hidden_states: torch.Tensor) -> torch.Tensor: - return self.fc(target_hidden_states) - @dataclass class Eagle3DraftOutput(ModelOutput): logits: Optional[torch.FloatTensor] = None + norm_hidden_state: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None @@ -385,16 +438,24 @@ class Eagle3DrafterForCausalLM(PreTrainedModel): def __init__(self, config): super().__init__(config) + + self.load_embedding_from_target = getattr(config, "load_embedding_from_target", False) + self.load_lm_head_from_target = getattr(config, "load_lm_head_from_target", False) + self.model = Eagle3Model(config) self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False) + self.lm_head = ( + None + if self.load_lm_head_from_target + else nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False) + ) eagle_config = getattr(config, "eagle_config", {}) self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False) def forward( self, - input_ids: torch.LongTensor, + inputs_embeds: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> Eagle3DraftOutput: @@ -405,8 +466,8 @@ class Eagle3DrafterForCausalLM(PreTrainedModel): Raises: ValueError: If hidden_states is not provided in kwargs. 
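        Note:
            inputs_embeds has shape [batch_size, seq_len, hidden_size] and the hidden_states
            kwarg matches it. The returned Eagle3DraftOutput carries norm_hidden_state (the
            input to the lm_head) and last_hidden_state (fed back into the next drafting step).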
""" - batch_size, seq_len = input_ids.shape - device = input_ids.device + batch_size, seq_len, _ = inputs_embeds.shape + device = inputs_embeds.device # Generate position_ids if not provided if position_ids is None: @@ -418,17 +479,419 @@ class Eagle3DrafterForCausalLM(PreTrainedModel): raise ValueError("hidden_states must be provided.") hidden_states = self.model( - input_ids=input_ids, - position_ids=position_ids, - hidden_states=hidden_states, + inputs_embeds=inputs_embeds, position_ids=position_ids, hidden_states=hidden_states ) - norm_hidden_states = self.norm(hidden_states) - logits = self.lm_head(norm_hidden_states) + norm_hidden_state = self.norm(hidden_states) - last_hidden_state = norm_hidden_states if self._return_hidden_post_norm else hidden_states + last_hidden_state = norm_hidden_state if self._return_hidden_post_norm else hidden_states return Eagle3DraftOutput( - logits=logits, + norm_hidden_state=norm_hidden_state, last_hidden_state=last_hidden_state, ) + + def get_input_embeddings(self): + if self.model.embed_tokens is not None: + return self.model.embed_tokens + else: + raise NotImplementedError( + "Eagle3DrafterForCausalLM does not have an input embedding layer." + ) + + def get_output_embeddings(self): + if self.lm_head is not None: + return self.lm_head + else: + raise NotImplementedError( + "Eagle3DrafterForCausalLM does not have an output embedding layer." + ) + + +@dataclass +class EagleWrapperOutput(ModelOutput): + """Output format compatible with Eagle3OneModelSampler/MTPSampler. + + This output format allows the one-model speculative decoding flow to bypass + logits-based sampling in the sampler. The EagleWrapper performs all sampling + and verification internally, returning pre-computed tokens. + """ + + # logits: [batch_size, 1, vocab_size]. Used for compatibility. + logits: Optional[torch.Tensor] = None + + # new_tokens: [batch_size, max_draft_len + 1]. Accepted tokens from verification. + # This is a 2D tensor where each row contains the accepted tokens for a request, + # padded if fewer tokens were accepted. + new_tokens: Optional[torch.Tensor] = None + + # new_tokens_lens: [batch_size]. Number of newly accepted tokens per request in this iteration. + new_tokens_lens: Optional[torch.Tensor] = None + + # next_draft_tokens: [batch_size, max_draft_len]. Draft tokens for the next iteration. + # These are the tokens predicted by the draft model, already converted via d2t. + next_draft_tokens: Optional[torch.Tensor] = None + + # next_new_tokens: [batch_size, max_draft_len + 1]. Input tokens for the next iteration. + # Format: [last_accepted_token, draft_token_0, draft_token_1, ...] 
+ next_new_tokens: Optional[torch.Tensor] = None + + +@dataclass +class EagleWrapperConfig: + max_draft_len: int + load_embedding_from_target: bool + load_lm_head_from_target: bool + + +class EagleWrapper(nn.Module): + def __init__(self, config, target_model, draft_model, resource_manager): + super().__init__() + self.target_model = target_model + self.draft_model = draft_model + self.resource_manager = resource_manager + self.max_draft_len = config.max_draft_len + self.load_embedding_from_target = config.load_embedding_from_target + self.load_lm_head_from_target = config.load_lm_head_from_target + + def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply the fc layer that fuses hidden states from multiple target layers.""" + draft_model = self.draft_model.model + hidden_states = hidden_states.to(draft_model.dtype) + + fc = getattr(draft_model, "fc", None) + if fc is not None: + hidden_states = fc(hidden_states) + return hidden_states + + def apply_d2t(self, draft_output_ids: torch.Tensor) -> torch.Tensor: + """Apply draft-to-target token mapping if available.""" + d2t = getattr(self.draft_model.model, "d2t", None) + if d2t is not None: + draft_output_ids = d2t[draft_output_ids] + draft_output_ids + return draft_output_ids + + def apply_draft_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: + """Apply embedding to input_ids for the draft model.""" + if self.load_embedding_from_target: + embeds = self.target_model.get_input_embeddings()(input_ids) + return embeds.to(self.draft_model.dtype) + else: + return self.draft_model.get_input_embeddings()(input_ids) + + def apply_lm_head(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply lm_head to get logits from hidden states.""" + if self.load_lm_head_from_target: + lm_head_weights = self.target_model.get_output_embeddings()(hidden_states) + return lm_head_weights.to(self.draft_model.dtype) + else: + return self.draft_model.get_output_embeddings()(hidden_states) + + def sample_greedy(self, logits: torch.Tensor) -> torch.Tensor: + ret = torch.argmax(logits, dim=-1) + return ret + + def sample_and_verify( + self, input_ids, target_logits: torch.Tensor, num_previously_accepted: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + input_ids: [batch_size, seq_len] + target_logits: [batch_size, seq_len, vocab_size] + num_previously_accepted: [batch_size]. Number of input tokens accepted so far for each batch. + + Returns: + output_ids: [batch_size, seq_len] (result of greedy sampling on input ids) + num_newly_accepted_tokens: [batch_size]. Number of newly accepted tokens in each batch. + num_accepted_tokens: [batch_size]. Number of tokens accepted in each batch, including previously accepted. + So num_accepted_tokens[i] = num_previously_accepted + num_newly_accepted_tokens. + last_logits_3d: [batch_size, 1, vocab_size]. The logit used to sample the bonus token. + + How it works: + - Get input ids that were not previously accepted. + - Get the corresponding target logits to these input ids (target_logit[j-1] corresponds to input_ids[j]) + - Sample a token from the logits for each batch and compare to input_ids to get the newly accepted tokens. + - The output_ids consist of the previously accepted tokens, the newly accepted tokens, + and a newly sampled token after the last accepted token. 
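        Worked example (illustrative): for an input row [c0, c1, d0, d1, d2] with
        num_previously_accepted = 2, the greedy target picks at positions 1..3 are compared
        against [d0, d1, d2]. If only d0 matches, then num_newly_accepted = 1,
        num_accepted = 3, and the returned row is [c1, d0, b] padded with zeros, where b is
        the bonus token sampled from the logit at position 2.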
+ """ + + batch_size, seq_len = input_ids.shape + + # First, check that num_previously_accepted is <= seq_len for each batch + # Additionally, num_previously_accepted should be >= 1 for each batch, + # which corresponds to having some context tokens (context tokens are always accepted). + assert (num_previously_accepted >= 1).all(), ( + "num_previously_accepted must be >= 1. Please provide non-empty context in each batch." + ) + assert (num_previously_accepted <= seq_len).all(), ( + "num_previously_accepted must be <= seq_len for each batch" + ) + + # We get input tokens that were not yet accepted. + unchecked_input_ids = [ + input_ids[i, num_previously_accepted[i] : seq_len].unsqueeze(0) + for i in range(batch_size) + ] + + # We get the corresponding target logits for the unchecked input tokens. + # logit j-1 corresponds to input j. + # Note that because of our check that num_previously_accepted is >= 1 + # We also get the last output token for each batch, because we may need to append it + # at the end. + unchecked_target_logits = [ + target_logits[i, (num_previously_accepted[i] - 1) : seq_len, :].unsqueeze(0) + for i in range(batch_size) + ] + + unchecked_output_ids = [self.sample_greedy(x) for x in unchecked_target_logits] + + # corresponding_output_ids: [batch_size, seq_len - 1]. The output ids that correspond to the unchecked input ids + # Omits the last index because that corresponds to a freshly sampled output model. + corresponding_output_ids = [output_id[:, :-1] for output_id in unchecked_output_ids] + + # After sample_greedy, corresponding_output_ids should have same shape as unchecked_input_ids + assert [x.shape for x in unchecked_input_ids] == [ + x.shape for x in corresponding_output_ids + ], "unchecked_input_ids and corresponding_output_ids must have the same shape" + + matches = [ + (corresponding_output_ids[i] == unchecked_input_ids[i]).int() for i in range(batch_size) + ] + + # Compute num_newly_accepted_tokens per batch (handles different sizes across batches) + num_newly_accepted_tokens = [] + for i in range(batch_size): + if matches[i].numel() == 0: + # No unchecked tokens for this batch (num_previously_accepted == seq_len) + num_newly_accepted_tokens.append( + torch.tensor(0, dtype=torch.long, device=input_ids.device) + ) + else: + # prefix_matches[j] is 1 if first j+1 tokens all matched + prefix_matches = matches[i].cumprod(dim=-1) + num_newly_accepted_tokens.append(prefix_matches.sum().long()) + num_newly_accepted_tokens = torch.stack(num_newly_accepted_tokens) + + # num_accepted_tokens: [batch_size]. The total number of accepted tokens in each batch, + # including previously accepted tokens. + num_accepted_tokens = num_previously_accepted + num_newly_accepted_tokens + + assert (num_accepted_tokens <= seq_len).all(), ( + "num_accepted_tokens must be <= seq_len for each batch" + ) + + # Construct draft_input_ids for the draft model + # For each sequence: + # 1. Take previously accepted tokens (skipping the first one) + # 2. Append newly accepted tokens directly from input_ids. + # 3. Append the sampled token for last accepted position: unchecked_output_ids[0][num_newly_accepted] + # 4. Fill the rest with zeros (padding) + # Total real tokens: (num_previously_accepted - 1) + num_newly_accepted + 1 = num_accepted_tokens + + draft_input_ids = torch.zeros( + (batch_size, seq_len), dtype=input_ids.dtype, device=input_ids.device + ) + + for i in range(batch_size): + # 1. 
Previously accepted tokens (skip the first one in keeping with Eagle convention) + # Note that this potentially includes context tokens, but is structured this way because we + # want the output to contain the entire prefix of accepted tokens because the drafters have no KV cache. + prev_accepted = input_ids[i, 1 : num_previously_accepted[i]] + + # 2. Newly accepted input tokens + newly_accepted = input_ids[ + i, + num_previously_accepted[i] : num_previously_accepted[i] + + num_newly_accepted_tokens[i], + ] + + # 3. The sampled output token for the last accepted position + # unchecked_output_ids[i][j] is the sampled token for position (num_previously_accepted + j) + # We want the token for position num_accepted_tokens, which is index num_newly_accepted_tokens + next_token = unchecked_output_ids[i][0][num_newly_accepted_tokens[i]].unsqueeze(0) + + # Concatenate all parts + draft_prefix = torch.cat([prev_accepted, newly_accepted, next_token]) + + # Sanity check: draft_prefix length should equal num_accepted_tokens + assert draft_prefix.shape[0] == num_accepted_tokens[i], ( + f"draft_prefix length {draft_prefix.shape[0]} != num_accepted_tokens {num_accepted_tokens[i]}" + ) + + # Fill into draft_input_ids (rest remains zeros as padding) + draft_input_ids[i, : num_accepted_tokens[i]] = draft_prefix + + # Construct last_logits_3d: [batch_size, 1, vocab_size] + # This is the logit used to sample the bonus token for each sequence. + # The bonus token is sampled from unchecked_target_logits[i][0][num_newly_accepted_tokens[i]] + last_logits_list = [] + for i in range(batch_size): + # unchecked_target_logits[i] has shape [1, num_unchecked + 1, vocab_size] + # Index num_newly_accepted_tokens[i] gives the logit for the bonus token + bonus_logit = unchecked_target_logits[i][0, num_newly_accepted_tokens[i], :].unsqueeze( + 0 + ) + last_logits_list.append(bonus_logit) + last_logits_3d = torch.stack(last_logits_list, dim=0) # [batch_size, 1, vocab_size] + + return draft_input_ids, num_newly_accepted_tokens, num_accepted_tokens, last_logits_3d + + def forward(self, input_ids, position_ids, **kwargs): + """Dispatch to appropriate forward implementation based on kwargs. + + If num_previously_accepted is provided, use the prefill-only (no KV cache) implementation. + Otherwise, this is the KV cache case which is not yet implemented. + """ + num_previously_accepted = kwargs.get("num_previously_accepted", None) + + if num_previously_accepted is not None: + return self._forward_prefill_only(input_ids, position_ids, **kwargs) + else: + # KV cached case - not implemented yet + raise NotImplementedError( + "EagleWrapper forward with KV cache is not implemented. " + "This code path is reached when num_previously_accepted is not provided in kwargs." + ) + + def _forward_prefill_only(self, input_ids, position_ids, **kwargs): + """Forward pass without KV cache (prefill-only mode). + + This is the original implementation that recomputes all attention + from scratch on every forward call. 
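        Steps, as implemented below: (1) run the target model on embeddings of input_ids to
        get logits, (2) call sample_and_verify to accept tokens and obtain a bonus token,
        (3) read the captured hidden states from the resource manager and apply the Eagle3 fc
        fusion, (4) run max_draft_len drafting iterations, each appending one drafted token and
        its hidden state, and (5) return an EagleWrapperOutput with new_tokens and new_tokens_lens.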
+ """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + num_previously_accepted = kwargs.get("num_previously_accepted", None) + if num_previously_accepted is None: + raise ValueError("num_previously_accepted must be provided for prefill-only mode.") + + # Compute embeddings using the target embedding layer + input_embeds = self.target_model.get_input_embeddings()(input_ids) + + # target_logits: [batch_size, seq_len, vocab_size] + # Pass embeddings to target model instead of input_ids + target_logits = self.target_model( + inputs_embeds=input_embeds, position_ids=position_ids + ).logits + + # output_ids: [batch_size, seq_len]. Contains a prefix of accepted tokens from the target model, + # a generated token from the target model, and some padding to fill out the tensor. + # num_accepted_tokens: [batch_size]. The number of accepted tokens in each batch. + # num_newly_accepted_tokens: [batch_size]. The number of newly accepted tokens in each batch. + + output_ids, num_newly_accepted_tokens, num_accepted_tokens, _ = self.sample_and_verify( + input_ids, target_logits, num_previously_accepted + ) + + # Get hidden states from the resource manager + # resource_manager.hidden_states is [max_tokens, hidden_size * num_capture_layers] (flattened) + # We slice to get [batch_size * seq_len, hidden_size * num_capture_layers] + hidden_states = self.resource_manager.hidden_states[: (batch_size * seq_len), :] + + # Apply eagle3 fc to reduce hidden size. + # Note: Since we are in prefill-only mode, this is extremely wasteful - we will apply the eagle3 fc layer + # to hidden states that we have applied it to previously. But, this is generally the case in prefill-only mode. + # Input: [batch_size * seq_len, hidden_size * num_capture_layers] + # Output: [batch_size * seq_len, hidden_size] + hidden_states = self.apply_eagle3_fc(hidden_states) + + # Reshape from [batch_size * seq_len, hidden_size] to [batch_size, seq_len, hidden_size] + hidden_size = hidden_states.shape[-1] + hidden_states = hidden_states.view(batch_size, seq_len, hidden_size) + + # Create a working buffer for the drafting loop in [batch, seq + draft_len, hidden] format. + # This is separate from resource_manager.hidden_states which remains in flattened format. + all_hidden_states = torch.zeros( + (batch_size, seq_len + self.max_draft_len, hidden_size), + device=device, + dtype=hidden_states.dtype, + ) + # Copy the initial hidden states from target model + all_hidden_states[:, :seq_len, :] = hidden_states + + # Construct our inputs for the drafting loop. + # We want tensors that will be able to hold all the tokens we draft. + + dummy_input_ids = torch.zeros( + (batch_size, self.max_draft_len), device=device, dtype=output_ids.dtype + ) + + # draft_input_ids: [batch_size, seq_len + self.max_draft_len] + draft_input_ids = torch.cat((output_ids, dummy_input_ids), dim=1) + + draft_position_ids = 1 + torch.arange( + self.max_draft_len, device=device, dtype=torch.long + ).unsqueeze(0).expand(batch_size, -1) + + draft_position_ids = draft_position_ids + position_ids[:, -1:].expand( + -1, self.max_draft_len + ) + + # draft_position_ids: [batch_size, seq_len + self.max_draft_len] + # These position ids will work throughout the drafting loop. + draft_position_ids = torch.cat((position_ids, draft_position_ids), dim=1) + + # The number of tokens currently in the draft input ids. Possibly includes padding. 
+ curr_num_tokens = seq_len + + # [batch_size] + # The number of valid tokens currently in the draft input ids (does not include padding). + curr_valid_tokens = num_accepted_tokens.clone() + + batch_indices = torch.arange(batch_size, device=device) + + for _ in range(self.max_draft_len): + # Get the input ids, position ids, and hidden states for the current tokens. + # size of tensor is constant for the current iteration and constant across dimensions (curr_num_tokens) + # These tensors may correspond to padding tokens, but due to the causality of the draft model, + # we can extract the draft tokens and hidden states corresponding to the valid tokens. + + input_ids = draft_input_ids[:, :curr_num_tokens] + position_ids = draft_position_ids[:, :curr_num_tokens] + hidden_states = all_hidden_states[:, :curr_num_tokens, :] + + inputs_embeds = self.apply_draft_embedding(input_ids) + draft_output = self.draft_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + hidden_states=hidden_states, + ) + + draft_output_logits = self.apply_lm_head(draft_output.norm_hidden_state) + + # get the output logits for the latest valid token in each batch + # It is at curr_valid_tokens-1 due to 0-indexing. + latest_draft_logits = draft_output_logits[batch_indices, curr_valid_tokens - 1, :] + + # draft_output_tokens: [batch_size, 1] + draft_output_tokens = self.sample_greedy(latest_draft_logits) + + # if the lm_head outputs tokens from the draft vocab, we need to convert them to tokens + # from the target vocab before the next iteration. + draft_output_tokens = self.apply_d2t(draft_output_tokens) + + # insert the draft output tokens into the draft input ids. + draft_input_ids[batch_indices, curr_valid_tokens] = draft_output_tokens + + # Similarly, we want the hidden state for the latest drafted token in each batch. + # This is a draft hidden state for the token that was just created from the latest valid token. + + # [batch_size, seq_len + self.max_draft_len, hidden_size] + all_hidden_states[batch_indices, curr_valid_tokens, :] = draft_output.last_hidden_state[ + batch_indices, curr_valid_tokens - 1, : + ] + + curr_valid_tokens = curr_valid_tokens + 1 + curr_num_tokens = curr_num_tokens + 1 + + # Return the full draft_input_ids tensor for each batch element. + # The valid prefix within each tensor has length: + # num_previously_accepted[i] + num_newly_accepted_tokens[i] + max_draft_len + # Callers should use this to slice out the valid tokens if needed. + new_tokens = [draft_input_ids[i] for i in range(batch_size)] + + return EagleWrapperOutput( + new_tokens=new_tokens, + new_tokens_lens=num_newly_accepted_tokens, + ) diff --git a/tensorrt_llm/_torch/auto_deploy/models/eagle.py b/tensorrt_llm/_torch/auto_deploy/models/eagle.py index fd60b40b35..38464e6f4f 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/eagle.py +++ b/tensorrt_llm/_torch/auto_deploy/models/eagle.py @@ -22,14 +22,14 @@ Eagle drafter implementations. 
""" from contextlib import nullcontext -from typing import Dict, Type +from typing import Dict import torch.nn as nn from accelerate import init_empty_weights from torch._prims_common import DeviceLikeType -from transformers import PreTrainedModel -from .custom.modeling_eagle import Eagle3DrafterForCausalLM +from ..utils.logger import ad_logger +from .custom.modeling_eagle import Eagle3DrafterForCausalLM, EagleConfig from .factory import ModelFactoryRegistry from .hf import AutoModelForCausalLMFactory @@ -47,23 +47,28 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory): (e.g., "llama") along with Eagle-specific fields like draft_vocab_size. """ - # Map config model_type -> Eagle drafter model class - _drafter_model_mapping: Dict[str, Type[PreTrainedModel]] = { + _drafter_classes: Dict[str, type] = { "llama": Eagle3DrafterForCausalLM, } def _build_model(self, device: DeviceLikeType) -> nn.Module: model_config, unused_kwargs = self._get_model_config() - # Select the appropriate drafter class based on the base model type - match model_config.model_type: - case "llama": - drafter_cls = self._drafter_model_mapping["llama"] - case _: - raise ValueError( - f"Unsupported model_type '{model_config.model_type}' for Eagle drafter. " - f"Supported types: {list(self._drafter_model_mapping.keys())}" - ) + # Select the appropriate drafter class and config based on the base model type + model_type = model_config.model_type + if model_type not in self._drafter_classes: + raise ValueError( + f"Unsupported model_type '{model_type}' for Eagle drafter. " + f"Supported types: {list(self._drafter_classes.keys())}" + ) + drafter_cls = self._drafter_classes[model_type] + ad_logger.info( + f"EagleDrafterFactory: model_type='{model_type}' -> drafter_cls={drafter_cls.__name__}" + ) + + # Convert base config to EagleConfig, preserving existing values + # and applying model-specific defaults based on model_type + model_config = EagleConfig(model_config, model_type) # Build the model (same pattern as parent's _build_model) with (init_empty_weights if device == "meta" else nullcontext)(): @@ -83,7 +88,7 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory): return model - def build_and_load_model(self, device: DeviceLikeType) -> nn.Module: + def build_and_load_model(self, _device: DeviceLikeType) -> nn.Module: raise NotImplementedError( "EagleDrafterFactory does not support build_and_load_model(). " "Use build_model() + load_or_random_init() instead." 
diff --git a/tests/integration/defs/examples/test_ad_speculative_decoding.py b/tests/integration/defs/examples/test_ad_speculative_decoding.py index 19c5beb66e..2228c4a6f5 100644 --- a/tests/integration/defs/examples/test_ad_speculative_decoding.py +++ b/tests/integration/defs/examples/test_ad_speculative_decoding.py @@ -15,15 +15,27 @@ import os import re +from dataclasses import dataclass from pathlib import Path +from typing import Optional, Set import pytest import torch +import torch.nn as nn from build_and_run_ad import ExperimentConfig, main from defs.conftest import llm_models_root +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.masking_utils import create_causal_mask +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaModel +from transformers.utils.generic import ModelOutput from tensorrt_llm import SamplingParams from tensorrt_llm._torch.auto_deploy.llm import LLM +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( + EagleWrapper, + EagleWrapperConfig, +) from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig @@ -83,8 +95,6 @@ def run_with_autodeploy(model, speculative_config, batch_size): # Select prompts based on batch size selected_prompts = prompts[:batch_size] - spec_config = speculative_config - # Configure KV cache kv_cache_config = KvCacheConfig( free_gpu_memory_fraction=0.01, @@ -94,7 +104,7 @@ def run_with_autodeploy(model, speculative_config, batch_size): llm_args = { "model": model, "skip_loading_weights": False, - "speculative_config": spec_config, + "speculative_config": speculative_config, "runtime": "trtllm", "world_size": 1, "kv_cache_config": kv_cache_config, @@ -388,30 +398,20 @@ def test_eagle_model_with_weights(): print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}") print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}") - if missing_keys: - print("\nMissing keys (model expects but checkpoint doesn't have):") - for key in sorted(missing_keys): - if "embed_tokens" in key: - print(f" - {key} (expected: shared from target model)") - elif "rotary_emb" in key: - print(f" - {key} (expected: computed at runtime)") - else: - print(f" - {key}") - if unexpected_keys: print("\nUnexpected keys (in checkpoint but model doesn't expect):") for key in sorted(unexpected_keys): if "t2d" in key: - print(f" - {key} (expected: not used in Eagle3)") + print(f" - {key} (expected: not used in Eagle3 for Llama3.1-8B-Instruct)") else: print(f" - {key}") if loaded_keys: print(f"\nLoaded keys ({len(loaded_keys)} total):") - for key in sorted(loaded_keys)[:10]: + for key in sorted(loaded_keys)[:20]: print(f" - {key}") - if len(loaded_keys) > 10: - print(f" ... and {len(loaded_keys) - 10} more") + if len(loaded_keys) > 20: + print(f" ... 
and {len(loaded_keys) - 20} more") print("--- End Weight Analysis ---\n") @@ -419,15 +419,10 @@ def test_eagle_model_with_weights(): # These are the keys we expect based on Eagle3 architecture: # - embed_tokens: shared from target model (not in Eagle checkpoint) # - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead) - expected_missing_keys = {"model.embed_tokens.weight"} expected_unexpected_keys = {"model.t2d"} - assert missing_keys == expected_missing_keys, ( - f"Unexpected missing keys.\n" - f"Expected: {expected_missing_keys}\n" - f"Got: {missing_keys}\n" - f"Extra missing: {missing_keys - expected_missing_keys}\n" - f"Not missing (but expected): {expected_missing_keys - missing_keys}" + assert len(missing_keys) == 0, ( + f"Expect all keys to be loaded.\nKeys that are missing: {missing_keys}\n" ) assert unexpected_keys == expected_unexpected_keys, ( @@ -448,3 +443,777 @@ def test_eagle_model_with_weights(): print("Weights loaded successfully via factory interface!") model.eval() + + +############################################################################### +# Set up to test the prefill-only version of the EagleWrapper model in test_eagle_wrapper_forward(). +# This helps us guarantee that the EagleWrapper model, before it enters AutoDeploy, is working correctly, +# The test does not rely on any TRTLLM logic. +############################################################################### +class PrefillOnlyEagleResourceManager: + """Simple resource manager for Eagle speculative decoding (prefill-only variant). + + Stores hidden states for use by draft loop in EagleWrapper.forward(). + """ + + def __init__( + self, + hidden_size: int, + num_capture_layers: int, + max_batch_size: int, + max_seq_len: int, + max_draft_len: int, + target_dtype: torch.dtype, + ): + # Buffer for hidden states from target model: [max_tokens, hidden_size * num_capture_layers] + # Uses flattened 2D format to match ADHiddenStateManager + self.hidden_states = torch.empty( + max_batch_size * (max_seq_len + max_draft_len), + hidden_size * num_capture_layers, + device="cuda", + dtype=target_dtype, + ) + + +class LlamaModelWithCapture(LlamaModel): + """LlamaModel that captures un-normalized hidden states from specified layers. + + Overwrites the base model's forward method to capture hidden states from specified layers. + Base model's forward method is otherwise copied from LlamaModel in HuggingFace. + Takes PrefillOnlyEagleResourceManager as an argument to store captured hidden states. + """ + + def __init__( + self, + config, + layers_to_capture: Optional[Set[int]] = None, + resource_manager: Optional[PrefillOnlyEagleResourceManager] = None, + ): + super().__init__(config) + # layers_to_capture: set of layer indices (0-indexed) to capture + # If None, capture all layers + if layers_to_capture is None: + self.layers_to_capture = set(range(config.num_hidden_layers)) + else: + self.layers_to_capture = set(layers_to_capture) + + self.resource_manager = resource_manager + + # Validate layer indices + for idx in self.layers_to_capture: + if idx < 0 or idx >= config.num_hidden_layers: + raise ValueError( + f"Layer index {idx} out of range. 
" + f"Model has {config.num_hidden_layers} layers (0 to {config.num_hidden_layers - 1})" + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + # prefill only - no past key values. + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=None, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # Buffer to collect captured hidden states + captured_hidden_states = [] + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Capture this layer's output if it's in our list + if layer_idx in self.layers_to_capture: + captured_hidden_states.append(hidden_states) + + # Apply final normalization for last_hidden_state + last_hidden_state = self.norm(hidden_states) + + # Store captured hidden states in resource manager if available + # Resource manager uses 2D flattened format: [max_tokens, hidden_size * num_capture_layers] + if self.resource_manager is not None and captured_hidden_states: + concatenated = torch.cat(captured_hidden_states, dim=-1) + batch_size, seq_len, total_hidden_size = concatenated.shape + assert self.resource_manager.hidden_states.shape[-1] == total_hidden_size, ( + f"Resource manager buffer last dim {self.resource_manager.hidden_states.shape[-1]} " + f"!= concatenated hidden states last dim {total_hidden_size}" + ) + # Flatten to [batch_size * seq_len, total_hidden_size] for 2D format + flattened = concatenated.view(batch_size * seq_len, total_hidden_size) + self.resource_manager.hidden_states[: (batch_size * seq_len), :].copy_(flattened) + + return BaseModelOutputWithPast( + last_hidden_state=last_hidden_state, + hidden_states=tuple(captured_hidden_states) if captured_hidden_states else None, + ) + + +@dataclass +class LlamaForCausalLMOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +class LlamaForCausalLMWithCapture(nn.Module): + """Wrapper combining LlamaModelWithCapture with lm_head for EagleWrapper testing. + + EagleWrapper.forward() expects target_model(input_ids, position_ids) to return logits. + This class wraps LlamaModelWithCapture (which captures hidden states to resource manager) + and adds the lm_head to produce logits. 
+ """ + + def __init__(self, base_model, capture_model): + super().__init__() + self.model = capture_model # LlamaModelWithCapture with resource_manager + self.lm_head = base_model.lm_head + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ): + outputs = self.model( + input_ids=input_ids, inputs_embeds=inputs_embeds, position_ids=position_ids, **kwargs + ) + logits = self.lm_head(outputs.last_hidden_state) + return LlamaForCausalLMOutput( + logits=logits, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + ) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def get_output_embeddings(self): + return self.model.lm_head + + @classmethod + def from_pretrained( + cls, + model_name: str, + resource_manager, + capture_layers, + dtype=torch.bfloat16, + ): + """Load a base model and create a LlamaForCausalLMWithCapture with shared weights.""" + print(f"Loading {model_name}...") + base_model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=dtype, + device_map={"": 0}, + ) + base_model.eval() + + # Create LlamaModelWithCapture that shares weights with the base model + original_llama_model = base_model.model + + capture_model = LlamaModelWithCapture.__new__(LlamaModelWithCapture) + nn.Module.__init__(capture_model) + + capture_model.config = original_llama_model.config + capture_model.layers_to_capture = capture_layers + capture_model.resource_manager = resource_manager + + # Share all modules (no weight copying) + capture_model.embed_tokens = original_llama_model.embed_tokens + capture_model.layers = original_llama_model.layers + capture_model.norm = original_llama_model.norm + capture_model.rotary_emb = original_llama_model.rotary_emb + capture_model.gradient_checkpointing = original_llama_model.gradient_checkpointing + + return cls(base_model, capture_model) + + +def build_eagle_wrapper( + base_model_path: str, + eagle_model_path: str, + resource_manager: PrefillOnlyEagleResourceManager, + capture_layers: Set[int], + max_seq_len: int, + max_draft_len: int, + target_dtype: torch.dtype, + device: torch.device, +) -> tuple[EagleWrapper, nn.Module]: + """Build an EagleWrapper model for testing. + + This function encapsulates the model building logic using manual model building. + + Returns: + A tuple of (eagle_wrapper, target_model) where: + - eagle_wrapper: The EagleWrapper model ready for inference. + - target_model: The target model (for verification steps). + """ + # Build EagleWrapper manually. 
+ print("\n" + "-" * 40) + print("Building EagleWrapper") + print("-" * 40) + + # Create target model with capture + target_model = LlamaForCausalLMWithCapture.from_pretrained( + base_model_path, resource_manager, capture_layers, target_dtype + ) + print("✓ Created target model with capture") + + # Create draft model using EagleDrafterFactory (mimics production pipeline) + # This ensures weights are loaded correctly via the same path as AutoDeploy + print("\nCreating draft model via EagleDrafterFactory...") + draft_factory = EagleDrafterFactory( + model=eagle_model_path, + skip_loading_weights=False, + ) + + # Build model on meta device first, then load weights + draft_model = draft_factory.build_model("meta") + print(f" Model type: {type(draft_model).__name__}") + + # Load weights via factory + print(" Loading weights via factory.load_or_random_init()...") + draft_factory.load_or_random_init(draft_model, device) + draft_model.eval() + + # Create EagleWrapper config + wrapper_config = EagleWrapperConfig( + max_draft_len=max_draft_len, + load_embedding_from_target=draft_model.load_embedding_from_target, + load_lm_head_from_target=draft_model.load_lm_head_from_target, + ) + + # Build EagleWrapper (this also loads weights from target into draft model where necessary) + eagle_wrapper = EagleWrapper( + config=wrapper_config, + target_model=target_model, + draft_model=draft_model, + resource_manager=resource_manager, + ) + eagle_wrapper.eval() + print("✓ Built EagleWrapper") + + return eagle_wrapper, target_model + + +def generate_target_outputs( + target_model: nn.Module, + input_ids: torch.Tensor, + num_iterations: int, +) -> torch.Tensor: + """Generate tokens from target model using greedy sampling. + + Runs target_model.forward() in a loop, taking the last logit from each output, + applying greedy sampling with torch.argmax, and appending to input_ids. + + Args: + target_model: Model that returns logits from forward(input_ids, position_ids). + input_ids: Initial input token ids of shape [batch_size, seq_len]. + num_iterations: Number of tokens to generate. + + Returns: + output_ids: Tensor of shape [batch_size, seq_len + num_iterations] containing + the original input_ids plus the generated tokens. + """ + device = input_ids.device + init_seq_len = input_ids.shape[1] + print(f"Initial sequence length: {init_seq_len}") + current_ids = input_ids.clone() + + with torch.no_grad(): + for _ in range(num_iterations): + # Generate position_ids from current sequence length + seq_len = current_ids.shape[1] + position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + position_ids = position_ids.expand(current_ids.shape[0], -1) + + # Forward pass + logits = target_model(current_ids, position_ids=position_ids).logits + + # Take the last logit and apply greedy sampling + last_logits = logits[:, -1, :] # [batch_size, vocab_size] + next_token = torch.argmax(last_logits, dim=-1, keepdim=True) # [batch_size, 1] + + # Append to current_ids + current_ids = torch.cat([current_ids, next_token], dim=1) + + return current_ids + + +def print_token_analysis( + input_ids: torch.Tensor, + num_previously_accepted: torch.Tensor, + target_output_ids: torch.Tensor, + tokenizer, +) -> None: + """Print debug analysis of accepted vs speculative tokens for each batch. + + Args: + input_ids: Current input token ids of shape [batch_size, seq_len]. + num_previously_accepted: Number of accepted tokens per batch [batch_size]. 
+ target_output_ids: Reference output from target model [batch_size, total_seq_len]. + tokenizer: Tokenizer for decoding tokens to text. + """ + batch_size = input_ids.shape[0] + print("\n --- Token Analysis (per batch) ---") + + for i in range(batch_size): + prev_accepted_i = num_previously_accepted[i].item() + + # Accepted tokens (before speculation): input_ids[i, :num_previously_accepted[i]] + accepted_tokens = input_ids[i, :prev_accepted_i] + # Speculative tokens: input_ids[i, num_previously_accepted[i]:] + speculative_tokens = input_ids[i, prev_accepted_i:] + + # Target model's expected token at this position + if prev_accepted_i < target_output_ids.shape[1]: + target_token_at_pos = target_output_ids[i, prev_accepted_i] + else: + target_token_at_pos = None + + print(f"\n Batch {i}:") + print(f" num_previously_accepted: {prev_accepted_i}") + print( + f" Accepted tokens ({accepted_tokens.shape[0]} tokens): {accepted_tokens.tolist()}" + ) + accepted_text = tokenizer.decode(accepted_tokens, skip_special_tokens=True) + print(f' Accepted text: "{accepted_text}"') + print( + f" Speculative tokens ({speculative_tokens.shape[0]} tokens): {speculative_tokens.tolist()}" + ) + if speculative_tokens.shape[0] > 0: + spec_text = tokenizer.decode(speculative_tokens, skip_special_tokens=False) + print(f' Speculative text: "{spec_text}"') + if target_token_at_pos is not None: + target_tok_id = target_token_at_pos.item() + target_tok_str = tokenizer.decode([target_tok_id]) + print( + f' Target model\'s next token at pos {prev_accepted_i}: {target_tok_id} ("{target_tok_str}")' + ) + + +def manual_sample_and_verify( + next_target_inputs: list, + num_accepted_tokens: torch.Tensor, + target_model: nn.Module, + eagle_wrapper: nn.Module, + max_draft_len: int, + device: torch.device, +) -> list: + """Manually verify speculative tokens using sample_and_verify. + + This is used for batch_size > 1 where truncation prevents speculative tokens + from being fed back, so we verify them manually before truncation. + + Args: + next_target_inputs: List of tensors, one per batch element. + num_accepted_tokens: Number of tokens accepted so far per batch [batch_size]. + target_model: The target model for running forward pass. + eagle_wrapper: The EagleWrapper containing sample_and_verify. + max_draft_len: Maximum draft length (for capping counts). + device: Device to run on. + + Returns: + List of (num_accepted, num_speculative) tuples for each batch element. 
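        Note: num_accepted is capped at max_draft_len, and num_speculative is the number of
        speculative (draft) tokens included in this verification input.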
+ """ + batch_size = len(next_target_inputs) + + # Due to our truncation trick, all sequences should have the same length + seq_lens = [seq.shape[0] for seq in next_target_inputs] + assert all(slen == seq_lens[0] for slen in seq_lens), ( + f"All sequences should have same length due to truncation, got {seq_lens}" + ) + verify_seq_len = seq_lens[0] + + # Stack into batched tensor + stacked_inputs = torch.stack(next_target_inputs, dim=0) # [batch_size, seq_len] + + # Run target model forward to get logits + verify_position_ids = ( + torch.arange(verify_seq_len, device=device, dtype=torch.long) + .unsqueeze(0) + .expand(batch_size, -1) + ) + with torch.no_grad(): + verify_target_logits = target_model(stacked_inputs, position_ids=verify_position_ids).logits + + # new_num_previously_accepted = num_accepted_tokens + 1 + # This represents the tokens accepted after target model's output from this iteration + new_num_previously_accepted = num_accepted_tokens + 1 + + # Call sample_and_verify to get acceptance counts + _, verify_newly_accepted, _, _ = eagle_wrapper.sample_and_verify( + stacked_inputs, verify_target_logits, new_num_previously_accepted + ) + + # Build results list + results = [] + for i in range(batch_size): + num_accepted_i = min(verify_newly_accepted[i].item(), max_draft_len) + num_speculative = next_target_inputs[i].shape[0] - new_num_previously_accepted[i].item() + results.append((num_accepted_i, num_speculative)) + + return results + + +def verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted): + """Verify the output structure and values from EagleWrapper forward pass. + + Args: + output: The output from EagleWrapper forward pass. + tokenizer: The tokenizer for decoding tokens. + batch_size: The batch size. + num_previously_accepted: Tensor of previously accepted token counts. + """ + # Verify output structure + print("\nOutput verification:") + assert output is not None, "Output should not be None" + assert hasattr(output, "new_tokens"), "Output should have new_tokens" + assert hasattr(output, "new_tokens_lens"), "Output should have new_tokens_lens" + + print(f" new_tokens: {type(output.new_tokens)} with {len(output.new_tokens)} items") + for i, tokens in enumerate(output.new_tokens): + new_tokens_text = tokenizer.decode(tokens, skip_special_tokens=True) + print(f" batch {i}: shape {tokens.shape}, tokens: {tokens.tolist()}") + print(f' batch {i}: decoded: "{new_tokens_text}"') + + # Compute num_accepted_tokens from new_tokens_lens + num_previously_accepted + num_accepted_tokens = num_previously_accepted + output.new_tokens_lens + + print(f" new_tokens_lens: {output.new_tokens_lens}") + print(f" num_accepted_tokens (computed): {num_accepted_tokens}") + + # Verify new_tokens_lens is within expected bounds + assert output.new_tokens_lens.shape == (batch_size,), ( + f"new_tokens_lens shape should be ({batch_size},), got {output.new_tokens_lens.shape}" + ) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_eagle_wrapper_forward(batch_size: int): + """Test EagleWrapper forward pass with target and draft models. + + This test validates the full speculative decoding loop: + 1. Target model processes input and captures hidden states + 2. Draft model generates speculative tokens + 3. EagleWrapper orchestrates verification and drafting + + For batch size 1, we call EagleWrapper forward in the expected way. Each iteration generates a "golden token" + (target output) and draft tokens. 
We input all of them to the wrapper model, + which verifies the draft tokens against the target output. It then outputs the accepted tokens + and newly generated draft tokens, along with numbers of accepted tokens, and the process repeats. + + For batch size > 1, we need to work around the fact that as we run the loop described above, the sequences lengths + in the batch will get out of sync. So instead, we do not provide validated draft tokens as input in each iteration - + we just input the first accepted token from the previous iteration + (which we know was generated by the target model), which keeps the batches in sync. + + To verify that the output draft tokens are reasonable, we run a manual target model verification step + after each iteration. We record how many of the output draft tokens were accepted. + + In the end, we test that the acceptance ratio of the draft tokens generated by the EagleWrapper is reasonable. + + Args: + batch_size: Number of prompts to process in parallel. + """ + print("\n" + "=" * 80) + print("Test: EagleWrapper forward pass") + print("=" * 80) + + # Set random seeds for reproducibility + torch.manual_seed(42) + + # Get model paths using integration test conventions + base_model_path, _, eagle_model_path = get_model_paths() + eagle_path = Path(eagle_model_path) + + if not eagle_path.exists(): + pytest.skip("Eagle model not found (model missing)") + + # Configuration + capture_layers = {1, 15, 28} # Layers to capture for Eagle3 + num_capture_layers = len(capture_layers) + hidden_size = 4096 # Llama 3.1-8B hidden size + dtype = torch.bfloat16 + device = torch.device("cuda") + + # Test dimensions + max_batch_size = 4 + max_seq_len = 1024 + max_draft_len = 3 + + # Tokenize the test prompts + tokenizer = AutoTokenizer.from_pretrained(base_model_path) + # Llama uses left padding for batch inference + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + if batch_size == 1: + input_ids = tokenizer.encode(prompts[0], return_tensors="pt").to(device) + else: + tokenized = tokenizer( + prompts[:batch_size], + return_tensors="pt", + padding=True, + ) + input_ids = tokenized.input_ids.to(device) + + print(f"input_ids: {input_ids}") + seq_len = input_ids.shape[1] + init_seq_len = seq_len # Store initial sequence length for final comparison + + print("\nTest configuration:") + print(f" target_model: {base_model_path}") + print(f" eagle_model: {eagle_path}") + print(f" batch_size: {batch_size}, seq_len: {seq_len}") + print(f" max_draft_len: {max_draft_len}") + print(f" capture_layers: {capture_layers}") + print(f" prompts: {prompts[:batch_size]}") + print(f" input_ids: {input_ids}") + + # Create resource manager + resource_manager = PrefillOnlyEagleResourceManager( + hidden_size=hidden_size, + num_capture_layers=num_capture_layers, + max_batch_size=max_batch_size, + max_seq_len=max_seq_len, + max_draft_len=max_draft_len, + target_dtype=dtype, + ) + print("\n✓ Created resource manager") + print(f" target_hidden_states shape: {resource_manager.hidden_states.shape}") + + # Build eagle_wrapper and target_model using the refactored function + eagle_wrapper, target_model = build_eagle_wrapper( + base_model_path=base_model_path, + eagle_model_path=str(eagle_path), + resource_manager=resource_manager, + capture_layers=capture_layers, + max_seq_len=max_seq_len, + max_draft_len=max_draft_len, + target_dtype=dtype, + device=device, + ) + + # Create test inputs (input_ids already created from tokenizer above) + position_ids = ( + torch.arange(seq_len, 
device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + ) + # Set previously_accepted_tokens to the length of input_ids (all context tokens are accepted) + # Shape should be [batch_size] - a 1D tensor with one value per batch + num_previously_accepted = torch.full((batch_size,), seq_len, device=device, dtype=torch.long) + + print("\nTest inputs:") + print(f" input_ids shape: {input_ids.shape}") + print(f" input_ids: {input_ids}") + print(f" position_ids shape: {position_ids.shape}") + print(f" num_previously_accepted: {num_previously_accepted}") + + # Generate target model outputs with greedy sampling + print("\nGenerating target model outputs with greedy sampling (for verification)...") + target_output_ids = generate_target_outputs(target_model, input_ids, num_iterations=100) + print(f" target_output_ids shape: {target_output_ids.shape}") + print(f" target_output_ids: {target_output_ids}") + + # Decode to text as sanity check + generated_text = tokenizer.decode(target_output_ids[0], skip_special_tokens=True) + print(f"\n Target model greedy generation decoded text:\n {generated_text}") + + print("\n✓ EagleWrapper forward pass completed successfully!") + print("✓ Output structure verified") + print("✓ new_tokens_lens within expected bounds") + print("✓ Target model greedy generation completed") + + print("\n================================================") + + num_iterations = 70 + + # Dictionary to track distribution of new_tokens_lens + # keys: 0 to max_draft_len + # newly_accepted_counts[i]: number of times the number of accepted draft tokens was i + newly_accepted_counts = {i: 0 for i in range(max_draft_len + 1)} + + for iteration in range(num_iterations): + print(f"\n{'=' * 40}") + print(f"EagleWrapper forward pass - Iteration {iteration + 1}/{num_iterations}") + print(f"{'=' * 40}") + + seq_len = input_ids.shape[1] + + # Debug: Print speculative tokens, accepted tokens, and target comparison + print_token_analysis(input_ids, num_previously_accepted, target_output_ids, tokenizer) + + kwargs = { + "num_previously_accepted": num_previously_accepted, + } + with torch.no_grad(): + output = eagle_wrapper( + input_ids=input_ids, + position_ids=position_ids, + **kwargs, + ) + + verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted) + + # Prepare next_target_inputs + # output.new_tokens[i] contains the full draft_input_ids tensor, but the valid prefix + # has length num_accepted_tokens[i] + max_draft_len. We slice to get only valid tokens. + # We then prepend the first token from the previous iteration's input_ids. + # This prepending is only needed for prefill-only mode, since in the cached case, the first token + # will always be in the KV cache. + # Compute num_accepted_tokens from num_previously_accepted + new_tokens_lens + num_accepted_tokens = num_previously_accepted + output.new_tokens_lens + valid_prefix_len = num_accepted_tokens + max_draft_len + next_target_inputs = [ + torch.cat( + (input_ids[i, 0].unsqueeze(0), output.new_tokens[i][: valid_prefix_len[i]]), + dim=0, + ) + for i in range(batch_size) + ] + + # Track distribution of newly accepted tokens by reading new_tokens_lens from the output. + # For batch size = 1, we are inputting draft tokens to the wrapper model, so new_tokens_lens + # gives the number of accepted tokens from drafts in the previous iteration. 
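        # With max_draft_len = 3, newly_accepted_counts is a histogram over {0, 1, 2, 3}
        # accepted draft tokens per iteration.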
+ if batch_size == 1: + for val in output.new_tokens_lens.tolist(): + newly_accepted_counts[val] += 1 + print(f" newly_accepted_counts so far: {newly_accepted_counts}") + + # For batch_size > 1, we use manual target model verification below instead to check which of the draft tokens + # generated in *this* iteration would be accepted by the target model. + else: + # For batch_size > 1, verify acceptance using sample_and_verify() + # before truncation (since truncation prevents speculative tokens from being fed back) + verify_results = manual_sample_and_verify( + next_target_inputs, + num_accepted_tokens, + target_model, + eagle_wrapper, + max_draft_len, + device, + ) + + # Update newly_accepted_counts map + for i, (num_accepted_i, num_speculative) in enumerate(verify_results): + newly_accepted_counts[num_accepted_i] += 1 + print( + f" [Batch {i}] sample_and_verify: {num_accepted_i}/{num_speculative} speculative accepted" + ) + + # Truncate to keep shapes consistent across batches in each iteration. + # We know that the first token that is generated in this iteration is accepted, so it is "safe". + # All speculative tokens are truncated regardless of whether they are accepted or not. + # This is a hack to prevent the sequence lengths from getting out of sync across batches in each iteration + # without needing to change the padding every iteration. + truncate_len = input_ids.shape[1] + 1 + next_target_inputs = [seq[:truncate_len] for seq in next_target_inputs] + + next_target_inputs = torch.stack(next_target_inputs, dim=0) + + print(f" next_target_inputs: {next_target_inputs}") + print(f" next_target_inputs.shape: {next_target_inputs.shape}") + + # Update for next iteration + input_ids = next_target_inputs + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + position_ids = position_ids.expand(batch_size, -1) + + if batch_size > 1: + # For multi-batch: increment by 1 (we truncated, so just advance by one token) + num_previously_accepted = num_previously_accepted + 1 + else: + # For single batch: accept the tokens accepted in the previous iteration, plus one + # for the output token that was generated by the target. 
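            # Illustrative: if 2 of the 3 drafts were accepted this iteration, the prefix
            # advances by 3 tokens (2 accepted drafts + 1 bonus token from the target).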
+            num_previously_accepted = num_accepted_tokens + 1
+
+    print(f"\n{'=' * 40}")
+    print(f"Loop completed: {num_iterations} iterations")
+    print("Newly accepted tokens distribution:")
+    for k, v in newly_accepted_counts.items():
+        print(f"  {k}: {v}")
+
+    # Calculate the acceptance ratio.
+    # For batch_size == 1: uses new_tokens_lens reported by the Eagle wrapper.
+    # For batch_size > 1: uses manual verification against the target model (since truncation
+    # prevents speculative tokens from being fed back).
+    total_accepted = sum(k * v for k, v in newly_accepted_counts.items())
+    # The first iteration has no draft tokens to accept; each subsequent iteration can accept up to max_draft_len per sequence.
+
+    num_iterations_with_drafts = num_iterations - 1 if batch_size == 1 else num_iterations
+    total_potential = max_draft_len * (num_iterations_with_drafts) * batch_size
+    acceptance_ratio = total_accepted / total_potential if total_potential > 0 else 0.0
+    print(f"\nAcceptance ratio: {total_accepted}/{total_potential} = {acceptance_ratio:.3f}")
+    if batch_size > 1:
+        print("  (batch_size > 1: measured via manual target model verification)")
+    assert acceptance_ratio > 0.1, (
+        f"Acceptance ratio {acceptance_ratio:.3f} is too low (expected > 0.1)"
+    )
+
+    print("\n" + "=" * 80)
+    print("FINAL OUTPUT COMPARISON")
+    print("=" * 80)
+    for i in range(batch_size):
+        print(f"\n{'─' * 40}")
+        print(f"BATCH {i}")
+        print(f"{'─' * 40}")
+        print(f"\n[Target Model Output] ({target_output_ids[i].shape[0]} tokens):")
+        print(f"  Tokens: {target_output_ids[i].tolist()}")
+        print(f'  Text: "{tokenizer.decode(target_output_ids[i], skip_special_tokens=True)}"')
+        print(f"\n[Eagle Wrapper Output] ({input_ids[i].shape[0]} tokens):")
+        print(f"  Tokens: {input_ids[i].tolist()}")
+        print(f'  Text: "{tokenizer.decode(input_ids[i], skip_special_tokens=True)}"')
+    print("\n" + "=" * 80)
+
+    # Verify that the first 10 generated tokens match between the target model and the Eagle wrapper.
+    # The two outputs diverge after a while but stay semantically equivalent.
+    # Running the target model in decode vs. prefill mode diverges similarly, so the divergence is not
+    # a concern; this check only asserts that the two generations are "similar enough".
+ num_tokens_to_check = 10 + print(f"\nVerifying first {num_tokens_to_check} generated tokens match...") + for i in range(batch_size): + target_generated = target_output_ids[i, init_seq_len : init_seq_len + num_tokens_to_check] + eagle_generated = input_ids[i, init_seq_len : init_seq_len + num_tokens_to_check] + + print(f" Batch {i}:") + print(f" Target: {target_generated.tolist()}") + print(f" Eagle: {eagle_generated.tolist()}") + + assert torch.equal(target_generated, eagle_generated), ( + f"Batch {i}: First {num_tokens_to_check} generated tokens do not match!\n" + f" Target: {target_generated.tolist()}\n" + f" Eagle: {eagle_generated.tolist()}" + ) + print(f"✓ First {num_tokens_to_check} generated tokens match for all batches!") diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index bec6a89f90..0e744e51de 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -443,3 +443,5 @@ l0_h100: - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3] - examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate - examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights + - examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[1] + - examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[2] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py index bbf4ede408..85d2d2d2f4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py @@ -21,11 +21,11 @@ import pytest import torch from _model_test_utils import get_small_model_config from build_and_run_ad import ExperimentConfig, main -from transformers import AutoConfig from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( Eagle3DrafterForCausalLM, - Eagle3Model, + Eagle3DraftOutput, + EagleConfig, ) from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry @@ -33,7 +33,6 @@ from tests.test_common.llm_data import hf_id_to_local_model_dir EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - ############################################################################### # Mock classes for standalone Eagle testing # @@ -43,6 +42,22 @@ EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" ############################################################################### +class MockEagleConfig(EagleConfig): + """Config for standalone Eagle testing with embedding/lm_head loaded from checkpoint. + + In production, Eagle shares embedding/lm_head with the target model. + For standalone testing, we need to load these from the checkpoint. + """ + + _drafter_defaults = { + "llama": { + "load_embedding_from_target": False, + "load_lm_head_from_target": False, + "num_capture_layers": 1, + }, + } + + class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM): """Test wrapper that provides random hidden states for standalone Eagle testing. 
@@ -55,7 +70,17 @@ class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
         self._hidden_size = config.hidden_size
         self._dtype = config.dtype
 
-    def forward(self, input_ids, **kwargs):
+    def forward(self, input_ids, position_ids, inputs_embeds=None, **kwargs):
+        assert self.model.embed_tokens is not None, (
+            "embed_tokens must be set before running standalone Eagle model."
+        )
+        assert self.lm_head is not None, (
+            "lm_head must be set before running standalone Eagle model."
+        )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+
         # Inject mock hidden states if not provided
         if "hidden_states" not in kwargs:
             batch_size, seq_len = input_ids.shape
@@ -64,19 +89,39 @@
                 dtype=self._dtype,
                 device=input_ids.device,
             )
-        return super().forward(input_ids, **kwargs)
+        draft_output = super().forward(inputs_embeds, position_ids, **kwargs)
+        logits = self.lm_head(draft_output.norm_hidden_state)
+        return Eagle3DraftOutput(logits=logits, last_hidden_state=draft_output.last_hidden_state)
 
 
 class MockEagleDrafterFactory(EagleDrafterFactory):
     """Test factory that uses MockEagle3ModelForCausalLM for standalone Eagle testing.
 
-    This factory overrides the drafter mapping to use the mock model class which
-    generates random hidden states, enabling testing without a target model.
+    This factory directly builds MockEagle3ModelForCausalLM with MockEagleConfig,
+    which loads embedding/lm_head from checkpoint for standalone testing.
     """
 
-    _drafter_model_mapping = {
-        "llama": MockEagle3ModelForCausalLM,
-    }
+    def _build_model(self, device):
+        from contextlib import nullcontext
+
+        from accelerate import init_empty_weights
+
+        model_config, unused_kwargs = self._get_model_config()
+        model_config = MockEagleConfig(model_config, model_config.model_type)
+
+        with (init_empty_weights if device == "meta" else nullcontext)():
+            model = MockEagle3ModelForCausalLM._from_config(model_config, **unused_kwargs)
+
+        if device == "meta":
+            if hasattr(model, "post_init"):
+                model.post_init()
+        else:
+            model.to(device)
+
+        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
+        model.eval()
+
+        return model
 
 
 @pytest.fixture
@@ -125,7 +170,7 @@ def test_eagle_model_torch_export():
     torch.export for potential TensorRT compilation.
 
     Note: We skip loading weights since torch.export only traces the computation
-    graph (model architecture), not the actual weight values. Random init is fine.
+    graph (model architecture).
     """
     print("\n" + "=" * 80)
    print("Test: EagleModel torch.export")
@@ -141,40 +186,35 @@ def test_eagle_model_torch_export():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     dtype = torch.float16
 
-    # Use pretrained config (llama) - can provide it directly to instantiate Eagle3Model.
- config_path = eagle_path / "config.json" - config = AutoConfig.from_pretrained(config_path) - - # Create model with random weights (no need to load for export test) - model = Eagle3Model(config) - model.to(device) - model.eval() + # Create model via EagleDrafterFactory (creates Eagle3DrafterForCausalLM) + factory = EagleDrafterFactory(model=str(eagle_path), skip_loading_weights=True) + model = factory.build_model(device) + config = model.config # Create inputs for export batch_size = 1 seq_len = 8 hidden_dim = config.hidden_size - input_ids = torch.randint( - 0, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long - ) + inputs_embeds = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype) position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) mock_hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype) print("Export input shapes:") - print(f" input_ids: {input_ids.shape}") + print(f" inputs_embeds: {inputs_embeds.shape}") print(f" position_ids: {position_ids.shape}") print(f" hidden_states: {mock_hidden_states.shape}") example_args = ( - input_ids, + inputs_embeds, position_ids, - mock_hidden_states, ) # Attempt torch.export try: - exported_program = torch.export.export(model, args=example_args) + exported_program = torch.export.export( + model, args=example_args, kwargs={"hidden_states": mock_hidden_states} + ) print("✅ torch.export successful!") print("Graph module code preview (first 20 lines):") code_lines = exported_program.graph_module.code.split("\n")[:20] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py index 5b78647083..3d4e8a8794 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py @@ -21,6 +21,9 @@ from test_common.llm_data import with_mocked_hf_download from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig +@pytest.mark.skip( + reason="OOM on A30 GPUs on CI - speculative model loading does not support model_kwargs reduction" +) @pytest.mark.parametrize("use_hf_speculative_model", [False]) @with_mocked_hf_download def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):
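
For reviewers, the per-iteration bookkeeping in test_eagle_wrapper_forward can be summarized by the sketch below. It is an illustrative condensation, not the patch's test code: the wrapper call signature and the new_tokens / new_tokens_lens fields follow the test above, while the helper name eagle_wrapper_step and the choice to return a list (the test stacks after truncating when batch_size > 1) are made here only for brevity.

# Illustrative sketch only: assumes the eagle_wrapper output contract described in the
# comments of test_eagle_wrapper_forward above; eagle_wrapper_step is a hypothetical helper.
from typing import List, Tuple

import torch


def eagle_wrapper_step(
    eagle_wrapper,
    input_ids: torch.Tensor,                # [batch_size, seq_len]
    num_previously_accepted: torch.Tensor,  # [batch_size]
    max_draft_len: int,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    """Run one prefill-only wrapper call and rebuild the inputs for the next call.

    Per the test's comments: new_tokens[i] holds the full draft_input_ids tensor (it does not
    include the very first input token), and new_tokens_lens[i] is how many of the previous
    iteration's draft tokens were accepted.
    """
    batch_size, seq_len = input_ids.shape
    position_ids = (
        torch.arange(seq_len, device=input_ids.device, dtype=torch.long)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    with torch.no_grad():
        output = eagle_wrapper(
            input_ids=input_ids,
            position_ids=position_ids,
            num_previously_accepted=num_previously_accepted,
        )

    # The accepted prefix grows by the newly accepted draft tokens; the valid part of
    # new_tokens[i] is that prefix plus max_draft_len fresh draft tokens.
    num_accepted_tokens = num_previously_accepted + output.new_tokens_lens
    next_target_inputs = [
        torch.cat(
            (input_ids[i, :1], output.new_tokens[i][: num_accepted_tokens[i] + max_draft_len])
        )
        for i in range(batch_size)
    ]
    # Prepending input_ids[i, 0] is only needed in prefill-only mode; with a KV cache the first
    # token would already be cached. The next call also counts the bonus token sampled by the
    # target model, hence the +1.
    return next_target_inputs, num_accepted_tokens + 1

The test's single-batch path feeds these outputs straight back; the multi-batch path first truncates each sequence to input_ids.shape[1] + 1 so the batch stays rectangular.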
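
The standalone-test configuration can also be seen directly at the config level. The snippet below is a minimal sketch under stated assumptions: eagle_checkpoint_dir is a hypothetical local path to the Eagle3 checkpoint, its config.json is assumed not to define these keys itself (otherwise the checkpoint values win and the asserts would not hold), and model_type is assumed to resolve to "llama".

# Minimal sketch (assumptions noted above): MockEagleConfig keeps the base checkpoint config
# but flips the drafter defaults so embedding/lm_head come from the checkpoint itself.
from transformers import AutoConfig

base_cfg = AutoConfig.from_pretrained(eagle_checkpoint_dir)  # hypothetical local checkpoint dir

mock_cfg = MockEagleConfig(base_cfg, base_cfg.model_type)
assert mock_cfg.load_embedding_from_target is False  # mock loads embeddings from the checkpoint
assert mock_cfg.load_lm_head_from_target is False
assert mock_cfg.num_capture_layers == 1              # mock injects a single layer of hidden states

The production factory path uses EagleConfig with its own defaults, where the drafter shares the embedding with the target model; only the mock overrides them for standalone runs.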