[#10826][feat] AutoDeploy: Eagle One-Model [2/n]: Prefill-Only Implementation (#11073)

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
gramnarayan 2026-02-02 09:51:10 -08:00 committed by GitHub
parent 3ef8a4639b
commit 585fbb2734
GPG Key ID: B5690EEEBB952194
6 changed files with 1376 additions and 94 deletions

View File

@ -22,14 +22,59 @@ This file contains model definitions used for executing Eagle3 speculative decod
"""
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput
from ...utils._config import deep_merge_dicts
from ...utils.logger import ad_logger
class EagleConfig(PretrainedConfig):
"""Config for Eagle3 drafter models.
Extends PretrainedConfig with Eagle-specific parameters while preserving
all base model config values.
Args:
config: Base config for the draft model from its config.json.
model_type: The base model type (e.g., "llama") used to look up defaults.
"""
# Map model_type -> default Eagle config values
_drafter_defaults: Dict[str, Dict[str, Any]] = {
"llama": {
"load_embedding_from_target": True,
"load_lm_head_from_target": False,
"num_capture_layers": 3,
},
}
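# Illustrative example (not from this commit): if a llama draft model's config.json sets
# num_capture_layers=1, the merge in __init__ keeps 1 and only fills in the two
# load_*_from_target defaults that the config does not define.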
def __init__(self, config: PretrainedConfig, model_type: str):
if model_type not in self._drafter_defaults:
raise ValueError(
f"Unsupported model_type '{model_type}' for EagleConfig. "
f"Supported types: {list(self._drafter_defaults.keys())}"
)
defaults = self._drafter_defaults[model_type]
config_dict = config.to_dict()
# Log when config overrides a default
for key, value in defaults.items():
if key in config_dict and config_dict[key] != value:
ad_logger.info(
f"EagleConfig: config has '{key}={config_dict[key]}', "
f"overriding default '{value}'"
)
merged = deep_merge_dicts(defaults, config_dict)
super().__init__(**merged)
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, config, dim, device=None, scaling_factor=1.0):
@ -247,6 +292,7 @@ class Eagle3DecoderLayer(nn.Module):
def __init__(self, config, layer_idx: int = 0):
super().__init__()
self.dtype = config.torch_dtype
self.self_attn = Eagle3Attention(config, layer_idx=layer_idx)
self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -287,7 +333,14 @@ class Eagle3Model(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.dtype = config.torch_dtype
load_embedding_from_target = getattr(config, "load_embedding_from_target", False)
self.embed_tokens = (
None
if load_embedding_from_target
else nn.Embedding(config.vocab_size, config.hidden_size)
)
if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size:
# Vocab mappings for draft <-> target token conversion
@ -299,13 +352,17 @@ class Eagle3Model(nn.Module):
requires_grad=False,
)
# Input feature fusion: 3 * hidden_size -> hidden_size for Eagle3.
# TODO: Can make this configurable based on number of capture layers.
self.fc = nn.Linear(
config.hidden_size * 3,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype,
# Hidden size compression for target hidden states.
# Assumption: No feedforward fusion needed if we have just one capture layer (valid for MTPEagle)
self.fc = (
nn.Linear(
config.hidden_size * config.num_capture_layers,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=self.dtype,
)
if config.num_capture_layers > 1
else None
)
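# For example, with the llama defaults above (num_capture_layers=3) and hidden_size=4096
# (e.g. Llama 3.1-8B), the fc layer maps 3 * 4096 = 12288 fused features down to 4096;
# with a single capture layer no fc layer is created.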
self.head_dim = getattr(
@ -328,12 +385,10 @@ class Eagle3Model(nn.Module):
# Assumption: The hidden states are already fused if necessary
def forward(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.Tensor,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_embeds = self.embed_tokens(input_ids)
cos, sin = self.rotary_emb(hidden_states, position_ids)
position_embeds = (cos, sin)
@ -341,25 +396,23 @@ class Eagle3Model(nn.Module):
for layer in self.midlayer:
hidden_states = layer(
hidden_states=hidden_states,
embeds=input_embeds,
embeds=inputs_embeds,
position_embeds=position_embeds,
)
else:
hidden_states = self.midlayer(
hidden_states=hidden_states,
embeds=input_embeds,
embeds=inputs_embeds,
position_embeds=position_embeds,
)
return hidden_states
def apply_eagle3_fc(self, target_hidden_states: torch.Tensor) -> torch.Tensor:
return self.fc(target_hidden_states)
@dataclass
class Eagle3DraftOutput(ModelOutput):
logits: Optional[torch.FloatTensor] = None
norm_hidden_state: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
@ -385,16 +438,24 @@ class Eagle3DrafterForCausalLM(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.load_embedding_from_target = getattr(config, "load_embedding_from_target", False)
self.load_lm_head_from_target = getattr(config, "load_lm_head_from_target", False)
self.model = Eagle3Model(config)
self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
self.lm_head = (
None
if self.load_lm_head_from_target
else nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
)
eagle_config = getattr(config, "eagle_config", {})
self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)
def forward(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Eagle3DraftOutput:
@ -405,8 +466,8 @@ class Eagle3DrafterForCausalLM(PreTrainedModel):
Raises:
ValueError: If hidden_states is not provided in kwargs.
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
batch_size, seq_len, _ = inputs_embeds.shape
device = inputs_embeds.device
# Generate position_ids if not provided
if position_ids is None:
@ -418,17 +479,419 @@ class Eagle3DrafterForCausalLM(PreTrainedModel):
raise ValueError("hidden_states must be provided.")
hidden_states = self.model(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
inputs_embeds=inputs_embeds, position_ids=position_ids, hidden_states=hidden_states
)
norm_hidden_states = self.norm(hidden_states)
logits = self.lm_head(norm_hidden_states)
norm_hidden_state = self.norm(hidden_states)
last_hidden_state = norm_hidden_states if self._return_hidden_post_norm else hidden_states
last_hidden_state = norm_hidden_state if self._return_hidden_post_norm else hidden_states
return Eagle3DraftOutput(
logits=logits,
norm_hidden_state=norm_hidden_state,
last_hidden_state=last_hidden_state,
)
def get_input_embeddings(self):
if self.model.embed_tokens is not None:
return self.model.embed_tokens
else:
raise NotImplementedError(
"Eagle3DrafterForCausalLM does not have an input embedding layer."
)
def get_output_embeddings(self):
if self.lm_head is not None:
return self.lm_head
else:
raise NotImplementedError(
"Eagle3DrafterForCausalLM does not have an output embedding layer."
)
@dataclass
class EagleWrapperOutput(ModelOutput):
"""Output format compatible with Eagle3OneModelSampler/MTPSampler.
This output format allows the one-model speculative decoding flow to bypass
logits-based sampling in the sampler. The EagleWrapper performs all sampling
and verification internally, returning pre-computed tokens.
"""
# logits: [batch_size, 1, vocab_size]. Used for compatibility.
logits: Optional[torch.Tensor] = None
# new_tokens: [batch_size, max_draft_len + 1]. Accepted tokens from verification.
# This is a 2D tensor where each row contains the accepted tokens for a request,
# padded if fewer tokens were accepted.
new_tokens: Optional[torch.Tensor] = None
# new_tokens_lens: [batch_size]. Number of newly accepted tokens per request in this iteration.
new_tokens_lens: Optional[torch.Tensor] = None
# next_draft_tokens: [batch_size, max_draft_len]. Draft tokens for the next iteration.
# These are the tokens predicted by the draft model, already converted via d2t.
next_draft_tokens: Optional[torch.Tensor] = None
# next_new_tokens: [batch_size, max_draft_len + 1]. Input tokens for the next iteration.
# Format: [last_accepted_token, draft_token_0, draft_token_1, ...]
next_new_tokens: Optional[torch.Tensor] = None
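# Illustrative shapes (hypothetical values): with batch_size=2 and max_draft_len=3,
# logits is [2, 1, vocab_size], new_tokens is [2, 4], new_tokens_lens is [2],
# next_draft_tokens is [2, 3], and next_new_tokens is [2, 4].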
@dataclass
class EagleWrapperConfig:
max_draft_len: int
load_embedding_from_target: bool
load_lm_head_from_target: bool
class EagleWrapper(nn.Module):
def __init__(self, config, target_model, draft_model, resource_manager):
super().__init__()
self.target_model = target_model
self.draft_model = draft_model
self.resource_manager = resource_manager
self.max_draft_len = config.max_draft_len
self.load_embedding_from_target = config.load_embedding_from_target
self.load_lm_head_from_target = config.load_lm_head_from_target
def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Apply the fc layer that fuses hidden states from multiple target layers."""
draft_model = self.draft_model.model
hidden_states = hidden_states.to(draft_model.dtype)
fc = getattr(draft_model, "fc", None)
if fc is not None:
hidden_states = fc(hidden_states)
return hidden_states
def apply_d2t(self, draft_output_ids: torch.Tensor) -> torch.Tensor:
"""Apply draft-to-target token mapping if available."""
d2t = getattr(self.draft_model.model, "d2t", None)
if d2t is not None:
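# d2t stores a per-draft-id offset into the target vocab, so the mapping below is
# target_id = draft_id + d2t[draft_id]. Illustrative example (hypothetical values):
# d2t = [5, 5, 7] maps draft ids [0, 1, 2] to target ids [5, 6, 9].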
draft_output_ids = d2t[draft_output_ids] + draft_output_ids
return draft_output_ids
def apply_draft_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Apply embedding to input_ids for the draft model."""
if self.load_embedding_from_target:
embeds = self.target_model.get_input_embeddings()(input_ids)
return embeds.to(self.draft_model.dtype)
else:
return self.draft_model.get_input_embeddings()(input_ids)
def apply_lm_head(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Apply lm_head to get logits from hidden states."""
if self.load_lm_head_from_target:
logits = self.target_model.get_output_embeddings()(hidden_states)
return logits.to(self.draft_model.dtype)
else:
return self.draft_model.get_output_embeddings()(hidden_states)
def sample_greedy(self, logits: torch.Tensor) -> torch.Tensor:
ret = torch.argmax(logits, dim=-1)
return ret
def sample_and_verify(
self, input_ids, target_logits: torch.Tensor, num_previously_accepted: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
input_ids: [batch_size, seq_len]
target_logits: [batch_size, seq_len, vocab_size]
num_previously_accepted: [batch_size]. Number of input tokens accepted so far for each batch.
Returns:
output_ids: [batch_size, seq_len]. The accepted prefix (minus its first token) followed by a freshly sampled bonus token, zero-padded.
num_newly_accepted_tokens: [batch_size]. Number of newly accepted tokens in each batch.
num_accepted_tokens: [batch_size]. Number of tokens accepted in each batch, including previously accepted.
So num_accepted_tokens[i] = num_previously_accepted[i] + num_newly_accepted_tokens[i].
last_logits_3d: [batch_size, 1, vocab_size]. The logit used to sample the bonus token.
How it works:
- Get input ids that were not previously accepted.
- Get the target logits corresponding to these input ids (target_logits[j-1] corresponds to input_ids[j]).
- Sample a token from the logits for each batch and compare to input_ids to get the newly accepted tokens.
- The output_ids consist of the previously accepted tokens, the newly accepted tokens,
and a newly sampled token after the last accepted token.
"""
batch_size, seq_len = input_ids.shape
# First, check that num_previously_accepted is <= seq_len for each batch
# Additionally, num_previously_accepted should be >= 1 for each batch,
# which corresponds to having some context tokens (context tokens are always accepted).
assert (num_previously_accepted >= 1).all(), (
"num_previously_accepted must be >= 1. Please provide non-empty context in each batch."
)
assert (num_previously_accepted <= seq_len).all(), (
"num_previously_accepted must be <= seq_len for each batch"
)
# We get input tokens that were not yet accepted.
unchecked_input_ids = [
input_ids[i, num_previously_accepted[i] : seq_len].unsqueeze(0)
for i in range(batch_size)
]
# We get the corresponding target logits for the unchecked input tokens.
# logit j-1 corresponds to input j.
# The slice starts at num_previously_accepted - 1, which is a valid index because of the
# check above that num_previously_accepted >= 1.
# We also include the final logit of each batch, because we may need its sampled token as
# the bonus token appended at the end.
unchecked_target_logits = [
target_logits[i, (num_previously_accepted[i] - 1) : seq_len, :].unsqueeze(0)
for i in range(batch_size)
]
unchecked_output_ids = [self.sample_greedy(x) for x in unchecked_target_logits]
# corresponding_output_ids: list of [1, num_unchecked] tensors. The output ids that correspond to the unchecked input ids.
# Omits the last index because that corresponds to a freshly sampled output token.
corresponding_output_ids = [output_id[:, :-1] for output_id in unchecked_output_ids]
# After sample_greedy, corresponding_output_ids should have same shape as unchecked_input_ids
assert [x.shape for x in unchecked_input_ids] == [
x.shape for x in corresponding_output_ids
], "unchecked_input_ids and corresponding_output_ids must have the same shape"
matches = [
(corresponding_output_ids[i] == unchecked_input_ids[i]).int() for i in range(batch_size)
]
# Compute num_newly_accepted_tokens per batch (handles different sizes across batches)
num_newly_accepted_tokens = []
for i in range(batch_size):
if matches[i].numel() == 0:
# No unchecked tokens for this batch (num_previously_accepted == seq_len)
num_newly_accepted_tokens.append(
torch.tensor(0, dtype=torch.long, device=input_ids.device)
)
else:
# prefix_matches[j] is 1 if first j+1 tokens all matched
prefix_matches = matches[i].cumprod(dim=-1)
num_newly_accepted_tokens.append(prefix_matches.sum().long())
num_newly_accepted_tokens = torch.stack(num_newly_accepted_tokens)
# num_accepted_tokens: [batch_size]. The total number of accepted tokens in each batch,
# including previously accepted tokens.
num_accepted_tokens = num_previously_accepted + num_newly_accepted_tokens
assert (num_accepted_tokens <= seq_len).all(), (
"num_accepted_tokens must be <= seq_len for each batch"
)
# Construct draft_input_ids for the draft model
# For each sequence:
# 1. Take previously accepted tokens (skipping the first one)
# 2. Append newly accepted tokens directly from input_ids.
# 3. Append the sampled token for the last accepted position: unchecked_output_ids[i][0][num_newly_accepted_tokens[i]]
# 4. Fill the rest with zeros (padding)
# Total real tokens: (num_previously_accepted - 1) + num_newly_accepted + 1 = num_accepted_tokens
draft_input_ids = torch.zeros(
(batch_size, seq_len), dtype=input_ids.dtype, device=input_ids.device
)
for i in range(batch_size):
# 1. Previously accepted tokens (skip the first one in keeping with Eagle convention)
# Note that this potentially includes context tokens; it is structured this way because we
# want the output to contain the entire prefix of accepted tokens, since the drafter has no KV cache.
prev_accepted = input_ids[i, 1 : num_previously_accepted[i]]
# 2. Newly accepted input tokens
newly_accepted = input_ids[
i,
num_previously_accepted[i] : num_previously_accepted[i]
+ num_newly_accepted_tokens[i],
]
# 3. The sampled output token for the last accepted position
# unchecked_output_ids[i][0][j] is the sampled token for position (num_previously_accepted[i] + j)
# We want the token for position num_accepted_tokens, which is index num_newly_accepted_tokens
next_token = unchecked_output_ids[i][0][num_newly_accepted_tokens[i]].unsqueeze(0)
# Concatenate all parts
draft_prefix = torch.cat([prev_accepted, newly_accepted, next_token])
# Sanity check: draft_prefix length should equal num_accepted_tokens
assert draft_prefix.shape[0] == num_accepted_tokens[i], (
f"draft_prefix length {draft_prefix.shape[0]} != num_accepted_tokens {num_accepted_tokens[i]}"
)
# Fill into draft_input_ids (rest remains zeros as padding)
draft_input_ids[i, : num_accepted_tokens[i]] = draft_prefix
# Construct last_logits_3d: [batch_size, 1, vocab_size]
# This is the logit used to sample the bonus token for each sequence.
# The bonus token is sampled from unchecked_target_logits[i][0][num_newly_accepted_tokens[i]]
last_logits_list = []
for i in range(batch_size):
# unchecked_target_logits[i] has shape [1, num_unchecked + 1, vocab_size]
# Index num_newly_accepted_tokens[i] gives the logit for the bonus token
bonus_logit = unchecked_target_logits[i][0, num_newly_accepted_tokens[i], :].unsqueeze(
0
)
last_logits_list.append(bonus_logit)
last_logits_3d = torch.stack(last_logits_list, dim=0) # [batch_size, 1, vocab_size]
return draft_input_ids, num_newly_accepted_tokens, num_accepted_tokens, last_logits_3d
def forward(self, input_ids, position_ids, **kwargs):
"""Dispatch to appropriate forward implementation based on kwargs.
If num_previously_accepted is provided, use the prefill-only (no KV cache) implementation.
Otherwise, this is the KV cache case which is not yet implemented.
"""
num_previously_accepted = kwargs.get("num_previously_accepted", None)
if num_previously_accepted is not None:
return self._forward_prefill_only(input_ids, position_ids, **kwargs)
else:
# KV cached case - not implemented yet
raise NotImplementedError(
"EagleWrapper forward with KV cache is not implemented. "
"This code path is reached when num_previously_accepted is not provided in kwargs."
)
def _forward_prefill_only(self, input_ids, position_ids, **kwargs):
"""Forward pass without KV cache (prefill-only mode).
This is the original implementation that recomputes all attention
from scratch on every forward call.
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
num_previously_accepted = kwargs.get("num_previously_accepted", None)
if num_previously_accepted is None:
raise ValueError("num_previously_accepted must be provided for prefill-only mode.")
# Compute embeddings using the target embedding layer
input_embeds = self.target_model.get_input_embeddings()(input_ids)
# target_logits: [batch_size, seq_len, vocab_size]
# Pass embeddings to target model instead of input_ids
target_logits = self.target_model(
inputs_embeds=input_embeds, position_ids=position_ids
).logits
# output_ids: [batch_size, seq_len]. Contains a prefix of accepted tokens from the target model,
# a generated token from the target model, and some padding to fill out the tensor.
# num_accepted_tokens: [batch_size]. The number of accepted tokens in each batch.
# num_newly_accepted_tokens: [batch_size]. The number of newly accepted tokens in each batch.
output_ids, num_newly_accepted_tokens, num_accepted_tokens, _ = self.sample_and_verify(
input_ids, target_logits, num_previously_accepted
)
# Get hidden states from the resource manager
# resource_manager.hidden_states is [max_tokens, hidden_size * num_capture_layers] (flattened)
# We slice to get [batch_size * seq_len, hidden_size * num_capture_layers]
hidden_states = self.resource_manager.hidden_states[: (batch_size * seq_len), :]
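# For example (hypothetical sizes): with batch_size=2, seq_len=5, hidden_size=4096 and
# num_capture_layers=3, this slice has shape [10, 12288].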
# Apply eagle3 fc to reduce hidden size.
# Note: This is wasteful: we re-apply the eagle3 fc layer to hidden states it was already
# applied to in previous iterations. Such recomputation is inherent to prefill-only mode.
# Input: [batch_size * seq_len, hidden_size * num_capture_layers]
# Output: [batch_size * seq_len, hidden_size]
hidden_states = self.apply_eagle3_fc(hidden_states)
# Reshape from [batch_size * seq_len, hidden_size] to [batch_size, seq_len, hidden_size]
hidden_size = hidden_states.shape[-1]
hidden_states = hidden_states.view(batch_size, seq_len, hidden_size)
# Create a working buffer for the drafting loop in [batch, seq + draft_len, hidden] format.
# This is separate from resource_manager.hidden_states which remains in flattened format.
all_hidden_states = torch.zeros(
(batch_size, seq_len + self.max_draft_len, hidden_size),
device=device,
dtype=hidden_states.dtype,
)
# Copy the initial hidden states from target model
all_hidden_states[:, :seq_len, :] = hidden_states
# Construct our inputs for the drafting loop.
# We want tensors that will be able to hold all the tokens we draft.
dummy_input_ids = torch.zeros(
(batch_size, self.max_draft_len), device=device, dtype=output_ids.dtype
)
# draft_input_ids: [batch_size, seq_len + self.max_draft_len]
draft_input_ids = torch.cat((output_ids, dummy_input_ids), dim=1)
draft_position_ids = 1 + torch.arange(
self.max_draft_len, device=device, dtype=torch.long
).unsqueeze(0).expand(batch_size, -1)
draft_position_ids = draft_position_ids + position_ids[:, -1:].expand(
-1, self.max_draft_len
)
# draft_position_ids: [batch_size, seq_len + self.max_draft_len]
# These position ids will work throughout the drafting loop.
draft_position_ids = torch.cat((position_ids, draft_position_ids), dim=1)
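# Illustrative example (hypothetical sizes): with seq_len=4 and max_draft_len=2, a sequence with
# position_ids [0, 1, 2, 3] gets draft_position_ids [0, 1, 2, 3, 4, 5].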
# The number of tokens currently in the draft input ids. Possibly includes padding.
curr_num_tokens = seq_len
# [batch_size]
# The number of valid tokens currently in the draft input ids (does not include padding).
curr_valid_tokens = num_accepted_tokens.clone()
batch_indices = torch.arange(batch_size, device=device)
for _ in range(self.max_draft_len):
# Get the input ids, position ids, and hidden states for the current tokens.
# The tensor sizes are fixed within this iteration and identical across the batch (curr_num_tokens tokens each).
# These tensors may correspond to padding tokens, but due to the causality of the draft model,
# we can extract the draft tokens and hidden states corresponding to the valid tokens.
input_ids = draft_input_ids[:, :curr_num_tokens]
position_ids = draft_position_ids[:, :curr_num_tokens]
hidden_states = all_hidden_states[:, :curr_num_tokens, :]
inputs_embeds = self.apply_draft_embedding(input_ids)
draft_output = self.draft_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
hidden_states=hidden_states,
)
draft_output_logits = self.apply_lm_head(draft_output.norm_hidden_state)
# get the output logits for the latest valid token in each batch
# It is at curr_valid_tokens-1 due to 0-indexing.
latest_draft_logits = draft_output_logits[batch_indices, curr_valid_tokens - 1, :]
# draft_output_tokens: [batch_size]
draft_output_tokens = self.sample_greedy(latest_draft_logits)
# if the lm_head outputs tokens from the draft vocab, we need to convert them to tokens
# from the target vocab before the next iteration.
draft_output_tokens = self.apply_d2t(draft_output_tokens)
# insert the draft output tokens into the draft input ids.
draft_input_ids[batch_indices, curr_valid_tokens] = draft_output_tokens
# Similarly, we want the hidden state for the latest drafted token in each batch.
# This is a draft hidden state for the token that was just created from the latest valid token.
# [batch_size, seq_len + self.max_draft_len, hidden_size]
all_hidden_states[batch_indices, curr_valid_tokens, :] = draft_output.last_hidden_state[
batch_indices, curr_valid_tokens - 1, :
]
curr_valid_tokens = curr_valid_tokens + 1
curr_num_tokens = curr_num_tokens + 1
# Return the full draft_input_ids tensor for each batch element.
# The valid prefix within each tensor has length:
# num_previously_accepted[i] + num_newly_accepted_tokens[i] + max_draft_len
# Callers should use this to slice out the valid tokens if needed.
new_tokens = [draft_input_ids[i] for i in range(batch_size)]
return EagleWrapperOutput(
new_tokens=new_tokens,
new_tokens_lens=num_newly_accepted_tokens,
)

View File

@ -22,14 +22,14 @@ Eagle drafter implementations.
"""
from contextlib import nullcontext
from typing import Dict, Type
from typing import Dict
import torch.nn as nn
from accelerate import init_empty_weights
from torch._prims_common import DeviceLikeType
from transformers import PreTrainedModel
from .custom.modeling_eagle import Eagle3DrafterForCausalLM
from ..utils.logger import ad_logger
from .custom.modeling_eagle import Eagle3DrafterForCausalLM, EagleConfig
from .factory import ModelFactoryRegistry
from .hf import AutoModelForCausalLMFactory
@ -47,23 +47,28 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory):
(e.g., "llama") along with Eagle-specific fields like draft_vocab_size.
"""
# Map config model_type -> Eagle drafter model class
_drafter_model_mapping: Dict[str, Type[PreTrainedModel]] = {
_drafter_classes: Dict[str, type] = {
"llama": Eagle3DrafterForCausalLM,
}
def _build_model(self, device: DeviceLikeType) -> nn.Module:
model_config, unused_kwargs = self._get_model_config()
# Select the appropriate drafter class based on the base model type
match model_config.model_type:
case "llama":
drafter_cls = self._drafter_model_mapping["llama"]
case _:
raise ValueError(
f"Unsupported model_type '{model_config.model_type}' for Eagle drafter. "
f"Supported types: {list(self._drafter_model_mapping.keys())}"
)
# Select the appropriate drafter class and config based on the base model type
model_type = model_config.model_type
if model_type not in self._drafter_classes:
raise ValueError(
f"Unsupported model_type '{model_type}' for Eagle drafter. "
f"Supported types: {list(self._drafter_classes.keys())}"
)
drafter_cls = self._drafter_classes[model_type]
ad_logger.info(
f"EagleDrafterFactory: model_type='{model_type}' -> drafter_cls={drafter_cls.__name__}"
)
# Convert base config to EagleConfig, preserving existing values
# and applying model-specific defaults based on model_type
model_config = EagleConfig(model_config, model_type)
# Build the model (same pattern as parent's _build_model)
with (init_empty_weights if device == "meta" else nullcontext)():
@ -83,7 +88,7 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory):
return model
def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
def build_and_load_model(self, _device: DeviceLikeType) -> nn.Module:
raise NotImplementedError(
"EagleDrafterFactory does not support build_and_load_model(). "
"Use build_model() + load_or_random_init() instead."

View File

@ -15,15 +15,27 @@
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Set
import pytest
import torch
import torch.nn as nn
from build_and_run_ad import ExperimentConfig, main
from defs.conftest import llm_models_root
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.masking_utils import create_causal_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.utils.generic import ModelOutput
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
EagleWrapper,
EagleWrapperConfig,
)
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig
@ -83,8 +95,6 @@ def run_with_autodeploy(model, speculative_config, batch_size):
# Select prompts based on batch size
selected_prompts = prompts[:batch_size]
spec_config = speculative_config
# Configure KV cache
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=0.01,
@ -94,7 +104,7 @@ def run_with_autodeploy(model, speculative_config, batch_size):
llm_args = {
"model": model,
"skip_loading_weights": False,
"speculative_config": spec_config,
"speculative_config": speculative_config,
"runtime": "trtllm",
"world_size": 1,
"kv_cache_config": kv_cache_config,
@ -388,30 +398,20 @@ def test_eagle_model_with_weights():
print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}")
print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}")
if missing_keys:
print("\nMissing keys (model expects but checkpoint doesn't have):")
for key in sorted(missing_keys):
if "embed_tokens" in key:
print(f" - {key} (expected: shared from target model)")
elif "rotary_emb" in key:
print(f" - {key} (expected: computed at runtime)")
else:
print(f" - {key}")
if unexpected_keys:
print("\nUnexpected keys (in checkpoint but model doesn't expect):")
for key in sorted(unexpected_keys):
if "t2d" in key:
print(f" - {key} (expected: not used in Eagle3)")
print(f" - {key} (expected: not used in Eagle3 for Llama3.1-8B-Instruct)")
else:
print(f" - {key}")
if loaded_keys:
print(f"\nLoaded keys ({len(loaded_keys)} total):")
for key in sorted(loaded_keys)[:10]:
for key in sorted(loaded_keys)[:20]:
print(f" - {key}")
if len(loaded_keys) > 10:
print(f" ... and {len(loaded_keys) - 10} more")
if len(loaded_keys) > 20:
print(f" ... and {len(loaded_keys) - 20} more")
print("--- End Weight Analysis ---\n")
@ -419,15 +419,10 @@ def test_eagle_model_with_weights():
# These are the keys we expect based on Eagle3 architecture:
# - embed_tokens: shared from target model (not in Eagle checkpoint)
# - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead)
expected_missing_keys = {"model.embed_tokens.weight"}
expected_unexpected_keys = {"model.t2d"}
assert missing_keys == expected_missing_keys, (
f"Unexpected missing keys.\n"
f"Expected: {expected_missing_keys}\n"
f"Got: {missing_keys}\n"
f"Extra missing: {missing_keys - expected_missing_keys}\n"
f"Not missing (but expected): {expected_missing_keys - missing_keys}"
assert len(missing_keys) == 0, (
f"Expect all keys to be loaded.\nKeys that are missing: {missing_keys}\n"
)
assert unexpected_keys == expected_unexpected_keys, (
@ -448,3 +443,777 @@ def test_eagle_model_with_weights():
print("Weights loaded successfully via factory interface!")
model.eval()
###############################################################################
# Setup for testing the prefill-only version of the EagleWrapper model in test_eagle_wrapper_forward().
# This helps us verify that the EagleWrapper model works correctly before it enters AutoDeploy.
# The test does not rely on any TRTLLM logic.
###############################################################################
class PrefillOnlyEagleResourceManager:
"""Simple resource manager for Eagle speculative decoding (prefill-only variant).
Stores hidden states for use by the draft loop in EagleWrapper.forward().
"""
def __init__(
self,
hidden_size: int,
num_capture_layers: int,
max_batch_size: int,
max_seq_len: int,
max_draft_len: int,
target_dtype: torch.dtype,
):
# Buffer for hidden states from target model: [max_tokens, hidden_size * num_capture_layers]
# Uses flattened 2D format to match ADHiddenStateManager
self.hidden_states = torch.empty(
max_batch_size * (max_seq_len + max_draft_len),
hidden_size * num_capture_layers,
device="cuda",
dtype=target_dtype,
)
class LlamaModelWithCapture(LlamaModel):
"""LlamaModel that captures un-normalized hidden states from specified layers.
Overrides the base model's forward method to capture hidden states from the specified layers;
the forward method is otherwise copied from Hugging Face's LlamaModel.
Takes a PrefillOnlyEagleResourceManager as an argument to store captured hidden states.
"""
def __init__(
self,
config,
layers_to_capture: Optional[Set[int]] = None,
resource_manager: Optional[PrefillOnlyEagleResourceManager] = None,
):
super().__init__(config)
# layers_to_capture: set of layer indices (0-indexed) to capture
# If None, capture all layers
if layers_to_capture is None:
self.layers_to_capture = set(range(config.num_hidden_layers))
else:
self.layers_to_capture = set(layers_to_capture)
self.resource_manager = resource_manager
# Validate layer indices
for idx in self.layers_to_capture:
if idx < 0 or idx >= config.num_hidden_layers:
raise ValueError(
f"Layer index {idx} out of range. "
f"Model has {config.num_hidden_layers} layers (0 to {config.num_hidden_layers - 1})"
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
# prefill only - no past key values.
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# Buffer to collect captured hidden states
captured_hidden_states = []
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
# Capture this layer's output if it's in our list
if layer_idx in self.layers_to_capture:
captured_hidden_states.append(hidden_states)
# Apply final normalization for last_hidden_state
last_hidden_state = self.norm(hidden_states)
# Store captured hidden states in resource manager if available
# Resource manager uses 2D flattened format: [max_tokens, hidden_size * num_capture_layers]
if self.resource_manager is not None and captured_hidden_states:
concatenated = torch.cat(captured_hidden_states, dim=-1)
batch_size, seq_len, total_hidden_size = concatenated.shape
assert self.resource_manager.hidden_states.shape[-1] == total_hidden_size, (
f"Resource manager buffer last dim {self.resource_manager.hidden_states.shape[-1]} "
f"!= concatenated hidden states last dim {total_hidden_size}"
)
# Flatten to [batch_size * seq_len, total_hidden_size] for 2D format
flattened = concatenated.view(batch_size * seq_len, total_hidden_size)
self.resource_manager.hidden_states[: (batch_size * seq_len), :].copy_(flattened)
return BaseModelOutputWithPast(
last_hidden_state=last_hidden_state,
hidden_states=tuple(captured_hidden_states) if captured_hidden_states else None,
)
@dataclass
class LlamaForCausalLMOutput(ModelOutput):
logits: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
class LlamaForCausalLMWithCapture(nn.Module):
"""Wrapper combining LlamaModelWithCapture with lm_head for EagleWrapper testing.
EagleWrapper.forward() expects target_model(input_ids, position_ids) to return logits.
This class wraps LlamaModelWithCapture (which captures hidden states to resource manager)
and adds the lm_head to produce logits.
"""
def __init__(self, base_model, capture_model):
super().__init__()
self.model = capture_model # LlamaModelWithCapture with resource_manager
self.lm_head = base_model.lm_head
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
):
outputs = self.model(
input_ids=input_ids, inputs_embeds=inputs_embeds, position_ids=position_ids, **kwargs
)
logits = self.lm_head(outputs.last_hidden_state)
return LlamaForCausalLMOutput(
logits=logits,
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
)
def get_input_embeddings(self):
return self.model.embed_tokens
def get_output_embeddings(self):
return self.lm_head
@classmethod
def from_pretrained(
cls,
model_name: str,
resource_manager,
capture_layers,
dtype=torch.bfloat16,
):
"""Load a base model and create a LlamaForCausalLMWithCapture with shared weights."""
print(f"Loading {model_name}...")
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map={"": 0},
)
base_model.eval()
# Create LlamaModelWithCapture that shares weights with the base model
original_llama_model = base_model.model
capture_model = LlamaModelWithCapture.__new__(LlamaModelWithCapture)
nn.Module.__init__(capture_model)
capture_model.config = original_llama_model.config
capture_model.layers_to_capture = capture_layers
capture_model.resource_manager = resource_manager
# Share all modules (no weight copying)
capture_model.embed_tokens = original_llama_model.embed_tokens
capture_model.layers = original_llama_model.layers
capture_model.norm = original_llama_model.norm
capture_model.rotary_emb = original_llama_model.rotary_emb
capture_model.gradient_checkpointing = original_llama_model.gradient_checkpointing
return cls(base_model, capture_model)
def build_eagle_wrapper(
base_model_path: str,
eagle_model_path: str,
resource_manager: PrefillOnlyEagleResourceManager,
capture_layers: Set[int],
max_seq_len: int,
max_draft_len: int,
target_dtype: torch.dtype,
device: torch.device,
) -> tuple[EagleWrapper, nn.Module]:
"""Build an EagleWrapper model for testing.
This function encapsulates the manual model-building logic.
Returns:
A tuple of (eagle_wrapper, target_model) where:
- eagle_wrapper: The EagleWrapper model ready for inference.
- target_model: The target model (for verification steps).
"""
# Build EagleWrapper manually.
print("\n" + "-" * 40)
print("Building EagleWrapper")
print("-" * 40)
# Create target model with capture
target_model = LlamaForCausalLMWithCapture.from_pretrained(
base_model_path, resource_manager, capture_layers, target_dtype
)
print("✓ Created target model with capture")
# Create draft model using EagleDrafterFactory (mimics production pipeline)
# This ensures weights are loaded correctly via the same path as AutoDeploy
print("\nCreating draft model via EagleDrafterFactory...")
draft_factory = EagleDrafterFactory(
model=eagle_model_path,
skip_loading_weights=False,
)
# Build model on meta device first, then load weights
draft_model = draft_factory.build_model("meta")
print(f" Model type: {type(draft_model).__name__}")
# Load weights via factory
print(" Loading weights via factory.load_or_random_init()...")
draft_factory.load_or_random_init(draft_model, device)
draft_model.eval()
# Create EagleWrapper config
wrapper_config = EagleWrapperConfig(
max_draft_len=max_draft_len,
load_embedding_from_target=draft_model.load_embedding_from_target,
load_lm_head_from_target=draft_model.load_lm_head_from_target,
)
# Build EagleWrapper (it shares the target model's embedding/lm_head with the draft model where configured)
eagle_wrapper = EagleWrapper(
config=wrapper_config,
target_model=target_model,
draft_model=draft_model,
resource_manager=resource_manager,
)
eagle_wrapper.eval()
print("✓ Built EagleWrapper")
return eagle_wrapper, target_model
def generate_target_outputs(
target_model: nn.Module,
input_ids: torch.Tensor,
num_iterations: int,
) -> torch.Tensor:
"""Generate tokens from target model using greedy sampling.
Runs target_model.forward() in a loop, taking the last logit from each output,
applying greedy sampling with torch.argmax, and appending to input_ids.
Args:
target_model: Model that returns logits from forward(input_ids, position_ids).
input_ids: Initial input token ids of shape [batch_size, seq_len].
num_iterations: Number of tokens to generate.
Returns:
output_ids: Tensor of shape [batch_size, seq_len + num_iterations] containing
the original input_ids plus the generated tokens.
"""
device = input_ids.device
init_seq_len = input_ids.shape[1]
print(f"Initial sequence length: {init_seq_len}")
current_ids = input_ids.clone()
with torch.no_grad():
for _ in range(num_iterations):
# Generate position_ids from current sequence length
seq_len = current_ids.shape[1]
position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
position_ids = position_ids.expand(current_ids.shape[0], -1)
# Forward pass
logits = target_model(current_ids, position_ids=position_ids).logits
# Take the last logit and apply greedy sampling
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
next_token = torch.argmax(last_logits, dim=-1, keepdim=True) # [batch_size, 1]
# Append to current_ids
current_ids = torch.cat([current_ids, next_token], dim=1)
return current_ids
def print_token_analysis(
input_ids: torch.Tensor,
num_previously_accepted: torch.Tensor,
target_output_ids: torch.Tensor,
tokenizer,
) -> None:
"""Print debug analysis of accepted vs speculative tokens for each batch.
Args:
input_ids: Current input token ids of shape [batch_size, seq_len].
num_previously_accepted: Number of accepted tokens per batch [batch_size].
target_output_ids: Reference output from target model [batch_size, total_seq_len].
tokenizer: Tokenizer for decoding tokens to text.
"""
batch_size = input_ids.shape[0]
print("\n --- Token Analysis (per batch) ---")
for i in range(batch_size):
prev_accepted_i = num_previously_accepted[i].item()
# Accepted tokens (before speculation): input_ids[i, :num_previously_accepted[i]]
accepted_tokens = input_ids[i, :prev_accepted_i]
# Speculative tokens: input_ids[i, num_previously_accepted[i]:]
speculative_tokens = input_ids[i, prev_accepted_i:]
# Target model's expected token at this position
if prev_accepted_i < target_output_ids.shape[1]:
target_token_at_pos = target_output_ids[i, prev_accepted_i]
else:
target_token_at_pos = None
print(f"\n Batch {i}:")
print(f" num_previously_accepted: {prev_accepted_i}")
print(
f" Accepted tokens ({accepted_tokens.shape[0]} tokens): {accepted_tokens.tolist()}"
)
accepted_text = tokenizer.decode(accepted_tokens, skip_special_tokens=True)
print(f' Accepted text: "{accepted_text}"')
print(
f" Speculative tokens ({speculative_tokens.shape[0]} tokens): {speculative_tokens.tolist()}"
)
if speculative_tokens.shape[0] > 0:
spec_text = tokenizer.decode(speculative_tokens, skip_special_tokens=False)
print(f' Speculative text: "{spec_text}"')
if target_token_at_pos is not None:
target_tok_id = target_token_at_pos.item()
target_tok_str = tokenizer.decode([target_tok_id])
print(
f' Target model\'s next token at pos {prev_accepted_i}: {target_tok_id} ("{target_tok_str}")'
)
def manual_sample_and_verify(
next_target_inputs: list,
num_accepted_tokens: torch.Tensor,
target_model: nn.Module,
eagle_wrapper: nn.Module,
max_draft_len: int,
device: torch.device,
) -> list:
"""Manually verify speculative tokens using sample_and_verify.
This is used for batch_size > 1 where truncation prevents speculative tokens
from being fed back, so we verify them manually before truncation.
Args:
next_target_inputs: List of tensors, one per batch element.
num_accepted_tokens: Number of tokens accepted so far per batch [batch_size].
target_model: The target model for running forward pass.
eagle_wrapper: The EagleWrapper containing sample_and_verify.
max_draft_len: Maximum draft length (for capping counts).
device: Device to run on.
Returns:
List of (num_accepted, num_speculative) tuples for each batch element.
"""
batch_size = len(next_target_inputs)
# Due to our truncation trick, all sequences should have the same length
seq_lens = [seq.shape[0] for seq in next_target_inputs]
assert all(slen == seq_lens[0] for slen in seq_lens), (
f"All sequences should have same length due to truncation, got {seq_lens}"
)
verify_seq_len = seq_lens[0]
# Stack into batched tensor
stacked_inputs = torch.stack(next_target_inputs, dim=0) # [batch_size, seq_len]
# Run target model forward to get logits
verify_position_ids = (
torch.arange(verify_seq_len, device=device, dtype=torch.long)
.unsqueeze(0)
.expand(batch_size, -1)
)
with torch.no_grad():
verify_target_logits = target_model(stacked_inputs, position_ids=verify_position_ids).logits
# new_num_previously_accepted = num_accepted_tokens + 1
# This represents the tokens accepted after target model's output from this iteration
new_num_previously_accepted = num_accepted_tokens + 1
# Call sample_and_verify to get acceptance counts
_, verify_newly_accepted, _, _ = eagle_wrapper.sample_and_verify(
stacked_inputs, verify_target_logits, new_num_previously_accepted
)
# Build results list
results = []
for i in range(batch_size):
num_accepted_i = min(verify_newly_accepted[i].item(), max_draft_len)
num_speculative = next_target_inputs[i].shape[0] - new_num_previously_accepted[i].item()
results.append((num_accepted_i, num_speculative))
return results
def verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted):
"""Verify the output structure and values from EagleWrapper forward pass.
Args:
output: The output from EagleWrapper forward pass.
tokenizer: The tokenizer for decoding tokens.
batch_size: The batch size.
num_previously_accepted: Tensor of previously accepted token counts.
"""
# Verify output structure
print("\nOutput verification:")
assert output is not None, "Output should not be None"
assert hasattr(output, "new_tokens"), "Output should have new_tokens"
assert hasattr(output, "new_tokens_lens"), "Output should have new_tokens_lens"
print(f" new_tokens: {type(output.new_tokens)} with {len(output.new_tokens)} items")
for i, tokens in enumerate(output.new_tokens):
new_tokens_text = tokenizer.decode(tokens, skip_special_tokens=True)
print(f" batch {i}: shape {tokens.shape}, tokens: {tokens.tolist()}")
print(f' batch {i}: decoded: "{new_tokens_text}"')
# Compute num_accepted_tokens from new_tokens_lens + num_previously_accepted
num_accepted_tokens = num_previously_accepted + output.new_tokens_lens
print(f" new_tokens_lens: {output.new_tokens_lens}")
print(f" num_accepted_tokens (computed): {num_accepted_tokens}")
# Verify new_tokens_lens is within expected bounds
assert output.new_tokens_lens.shape == (batch_size,), (
f"new_tokens_lens shape should be ({batch_size},), got {output.new_tokens_lens.shape}"
)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_eagle_wrapper_forward(batch_size: int):
"""Test EagleWrapper forward pass with target and draft models.
This test validates the full speculative decoding loop:
1. Target model processes input and captures hidden states
2. Draft model generates speculative tokens
3. EagleWrapper orchestrates verification and drafting
For batch size 1, we call EagleWrapper forward in the expected way. Each iteration generates a "golden token"
(target output) and draft tokens. We input all of them to the wrapper model,
which verifies the draft tokens against the target output. It then outputs the accepted tokens
and newly generated draft tokens, along with the number of accepted tokens, and the process repeats.
For batch size > 1, we need to work around the fact that, as we run the loop described above, the sequence lengths
in the batch will get out of sync. So instead, we do not feed the draft tokens back as input in each iteration -
we just input the first accepted token from the previous iteration
(which we know was generated by the target model), which keeps the batches in sync.
To verify that the output draft tokens are reasonable, we run a manual target model verification step
after each iteration. We record how many of the output draft tokens were accepted.
In the end, we test that the acceptance ratio of the draft tokens generated by the EagleWrapper is reasonable.
Args:
batch_size: Number of prompts to process in parallel.
"""
print("\n" + "=" * 80)
print("Test: EagleWrapper forward pass")
print("=" * 80)
# Set random seeds for reproducibility
torch.manual_seed(42)
# Get model paths using integration test conventions
base_model_path, _, eagle_model_path = get_model_paths()
eagle_path = Path(eagle_model_path)
if not eagle_path.exists():
pytest.skip("Eagle model not found (model missing)")
# Configuration
capture_layers = {1, 15, 28} # Layers to capture for Eagle3
num_capture_layers = len(capture_layers)
hidden_size = 4096 # Llama 3.1-8B hidden size
dtype = torch.bfloat16
device = torch.device("cuda")
# Test dimensions
max_batch_size = 4
max_seq_len = 1024
max_draft_len = 3
# Tokenize the test prompts
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Llama uses left padding for batch inference
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
if batch_size == 1:
input_ids = tokenizer.encode(prompts[0], return_tensors="pt").to(device)
else:
tokenized = tokenizer(
prompts[:batch_size],
return_tensors="pt",
padding=True,
)
input_ids = tokenized.input_ids.to(device)
print(f"input_ids: {input_ids}")
seq_len = input_ids.shape[1]
init_seq_len = seq_len # Store initial sequence length for final comparison
print("\nTest configuration:")
print(f" target_model: {base_model_path}")
print(f" eagle_model: {eagle_path}")
print(f" batch_size: {batch_size}, seq_len: {seq_len}")
print(f" max_draft_len: {max_draft_len}")
print(f" capture_layers: {capture_layers}")
print(f" prompts: {prompts[:batch_size]}")
print(f" input_ids: {input_ids}")
# Create resource manager
resource_manager = PrefillOnlyEagleResourceManager(
hidden_size=hidden_size,
num_capture_layers=num_capture_layers,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_draft_len=max_draft_len,
target_dtype=dtype,
)
print("\n✓ Created resource manager")
print(f" target_hidden_states shape: {resource_manager.hidden_states.shape}")
# Build eagle_wrapper and target_model using the refactored function
eagle_wrapper, target_model = build_eagle_wrapper(
base_model_path=base_model_path,
eagle_model_path=str(eagle_path),
resource_manager=resource_manager,
capture_layers=capture_layers,
max_seq_len=max_seq_len,
max_draft_len=max_draft_len,
target_dtype=dtype,
device=device,
)
# Create test inputs (input_ids already created from tokenizer above)
position_ids = (
torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
)
# Set previously_accepted_tokens to the length of input_ids (all context tokens are accepted)
# Shape should be [batch_size] - a 1D tensor with one value per batch
num_previously_accepted = torch.full((batch_size,), seq_len, device=device, dtype=torch.long)
print("\nTest inputs:")
print(f" input_ids shape: {input_ids.shape}")
print(f" input_ids: {input_ids}")
print(f" position_ids shape: {position_ids.shape}")
print(f" num_previously_accepted: {num_previously_accepted}")
# Generate target model outputs with greedy sampling
print("\nGenerating target model outputs with greedy sampling (for verification)...")
target_output_ids = generate_target_outputs(target_model, input_ids, num_iterations=100)
print(f" target_output_ids shape: {target_output_ids.shape}")
print(f" target_output_ids: {target_output_ids}")
# Decode to text as sanity check
generated_text = tokenizer.decode(target_output_ids[0], skip_special_tokens=True)
print(f"\n Target model greedy generation decoded text:\n {generated_text}")
print("\n✓ EagleWrapper forward pass completed successfully!")
print("✓ Output structure verified")
print("✓ new_tokens_lens within expected bounds")
print("✓ Target model greedy generation completed")
print("\n================================================")
num_iterations = 70
# Dictionary to track distribution of new_tokens_lens
# keys: 0 to max_draft_len
# newly_accepted_counts[i]: number of times the number of accepted draft tokens was i
newly_accepted_counts = {i: 0 for i in range(max_draft_len + 1)}
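# Illustrative example (hypothetical counts): {0: 5, 1: 12, 2: 30, 3: 23} would mean that all
# max_draft_len=3 drafts were accepted in 23 of the iterations.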
for iteration in range(num_iterations):
print(f"\n{'=' * 40}")
print(f"EagleWrapper forward pass - Iteration {iteration + 1}/{num_iterations}")
print(f"{'=' * 40}")
seq_len = input_ids.shape[1]
# Debug: Print speculative tokens, accepted tokens, and target comparison
print_token_analysis(input_ids, num_previously_accepted, target_output_ids, tokenizer)
kwargs = {
"num_previously_accepted": num_previously_accepted,
}
with torch.no_grad():
output = eagle_wrapper(
input_ids=input_ids,
position_ids=position_ids,
**kwargs,
)
verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted)
# Prepare next_target_inputs
# output.new_tokens[i] contains the full draft_input_ids tensor, but the valid prefix
# has length num_accepted_tokens[i] + max_draft_len. We slice to get only valid tokens.
# We then prepend the first token from the previous iteration's input_ids.
# This prepending is only needed for prefill-only mode, since in the cached case, the first token
# will always be in the KV cache.
# Compute num_accepted_tokens from num_previously_accepted + new_tokens_lens
num_accepted_tokens = num_previously_accepted + output.new_tokens_lens
valid_prefix_len = num_accepted_tokens + max_draft_len
next_target_inputs = [
torch.cat(
(input_ids[i, 0].unsqueeze(0), output.new_tokens[i][: valid_prefix_len[i]]),
dim=0,
)
for i in range(batch_size)
]
# Track distribution of newly accepted tokens by reading new_tokens_lens from the output.
# For batch size = 1, we are inputting draft tokens to the wrapper model, so new_tokens_lens
# gives the number of accepted tokens from drafts in the previous iteration.
if batch_size == 1:
for val in output.new_tokens_lens.tolist():
newly_accepted_counts[val] += 1
print(f" newly_accepted_counts so far: {newly_accepted_counts}")
# For batch_size > 1, we use manual target model verification below instead to check which of the draft tokens
# generated in *this* iteration would be accepted by the target model.
else:
# For batch_size > 1, verify acceptance using sample_and_verify()
# before truncation (since truncation prevents speculative tokens from being fed back)
verify_results = manual_sample_and_verify(
next_target_inputs,
num_accepted_tokens,
target_model,
eagle_wrapper,
max_draft_len,
device,
)
# Update newly_accepted_counts map
for i, (num_accepted_i, num_speculative) in enumerate(verify_results):
newly_accepted_counts[num_accepted_i] += 1
print(
f" [Batch {i}] sample_and_verify: {num_accepted_i}/{num_speculative} speculative accepted"
)
# Truncate to keep shapes consistent across batches in each iteration.
# We know that the first token that is generated in this iteration is accepted, so it is "safe".
# All speculative tokens are truncated regardless of whether they are accepted or not.
# This is a hack to prevent the sequence lengths from getting out of sync across batches in each iteration
# without needing to change the padding every iteration.
truncate_len = input_ids.shape[1] + 1
next_target_inputs = [seq[:truncate_len] for seq in next_target_inputs]
next_target_inputs = torch.stack(next_target_inputs, dim=0)
print(f" next_target_inputs: {next_target_inputs}")
print(f" next_target_inputs.shape: {next_target_inputs.shape}")
# Update for next iteration
input_ids = next_target_inputs
seq_len = input_ids.shape[1]
position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
position_ids = position_ids.expand(batch_size, -1)
if batch_size > 1:
# For multi-batch: increment by 1 (we truncated, so just advance by one token)
num_previously_accepted = num_previously_accepted + 1
else:
# For single batch: accept the tokens accepted in the previous iteration, plus one
# for the output token that was generated by the target.
num_previously_accepted = num_accepted_tokens + 1
print(f"\n{'=' * 40}")
print(f"Loop completed: {num_iterations} iterations")
print("Newly accepted tokens distribution:")
for k, v in newly_accepted_counts.items():
print(f" {k}: {v}")
# Calculate acceptance ratio
# For batch_size == 1: uses new_tokens_lens from eagle wrapper
# For batch_size > 1: uses manual verification against target model (since truncation
# prevents speculative tokens from being fed back)
total_accepted = sum(k * v for k, v in newly_accepted_counts.items())
# First iteration has no tokens to newly accept, subsequent iterations have max_draft_len potential
num_iterations_with_drafts = num_iterations - 1 if batch_size == 1 else num_iterations
total_potential = max_draft_len * (num_iterations_with_drafts) * batch_size
acceptance_ratio = total_accepted / total_potential if total_potential > 0 else 0.0
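# For example (hypothetical numbers): for batch_size=1, total_accepted=140 over
# total_potential = 3 * 69 * 1 = 207 gives an acceptance ratio of about 0.68.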
print(f"\nAcceptance ratio: {total_accepted}/{total_potential} = {acceptance_ratio:.3f}")
if batch_size > 1:
print(" (batch_size > 1: measured via manual target model verification)")
assert acceptance_ratio > 0.1, (
f"Acceptance ratio {acceptance_ratio:.3f} is too low (expected > 0.1)"
)
print("\n" + "=" * 80)
print("FINAL OUTPUT COMPARISON")
print("=" * 80)
for i in range(batch_size):
print(f"\n{'' * 40}")
print(f"BATCH {i}")
print(f"{'' * 40}")
print(f"\n[Target Model Output] ({target_output_ids[i].shape[0]} tokens):")
print(f" Tokens: {target_output_ids[i].tolist()}")
print(f' Text: "{tokenizer.decode(target_output_ids[i], skip_special_tokens=True)}"')
print(f"\n[Eagle Wrapper Output] ({input_ids[i].shape[0]} tokens):")
print(f" Tokens: {input_ids[i].tolist()}")
print(f' Text: "{tokenizer.decode(input_ids[i], skip_special_tokens=True)}"')
print("\n" + "=" * 80)
# Verify that the first 10 generated tokens match between the target model and the eagle wrapper.
# The outputs tend to diverge after a while but stay semantically equivalent.
# Even running the target model in decode vs. prefill mode shows a similar divergence,
# so this is not a concern; the test just checks that the two outputs are "similar enough".
num_tokens_to_check = 10
print(f"\nVerifying first {num_tokens_to_check} generated tokens match...")
for i in range(batch_size):
target_generated = target_output_ids[i, init_seq_len : init_seq_len + num_tokens_to_check]
eagle_generated = input_ids[i, init_seq_len : init_seq_len + num_tokens_to_check]
print(f" Batch {i}:")
print(f" Target: {target_generated.tolist()}")
print(f" Eagle: {eagle_generated.tolist()}")
assert torch.equal(target_generated, eagle_generated), (
f"Batch {i}: First {num_tokens_to_check} generated tokens do not match!\n"
f" Target: {target_generated.tolist()}\n"
f" Eagle: {eagle_generated.tolist()}"
)
print(f"✓ First {num_tokens_to_check} generated tokens match for all batches!")

View File

@ -443,3 +443,5 @@ l0_h100:
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
- examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights
- examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[1]
- examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[2]

View File

@ -21,11 +21,11 @@ import pytest
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
from transformers import AutoConfig
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
Eagle3DrafterForCausalLM,
Eagle3Model,
Eagle3DraftOutput,
EagleConfig,
)
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry
@ -33,7 +33,6 @@ from tests.test_common.llm_data import hf_id_to_local_model_dir
EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
###############################################################################
# Mock classes for standalone Eagle testing
#
@ -43,6 +42,22 @@ EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
###############################################################################
class MockEagleConfig(EagleConfig):
"""Config for standalone Eagle testing with embedding/lm_head loaded from checkpoint.
In production, Eagle shares embedding/lm_head with the target model.
For standalone testing, we need to load these from the checkpoint.
"""
_drafter_defaults = {
"llama": {
"load_embedding_from_target": False,
"load_lm_head_from_target": False,
"num_capture_layers": 1,
},
}
class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
"""Test wrapper that provides random hidden states for standalone Eagle testing.
@ -55,7 +70,17 @@ class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
self._hidden_size = config.hidden_size
self._dtype = config.dtype
def forward(self, input_ids, **kwargs):
def forward(self, input_ids, position_ids, inputs_embeds=None, **kwargs):
assert self.model.embed_tokens is not None, (
"embed_tokens must be set before running standalone Eagle model."
)
assert self.lm_head is not None, (
"lm_head must be set before running standalone Eagle model."
)
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# Inject mock hidden states if not provided
if "hidden_states" not in kwargs:
batch_size, seq_len = input_ids.shape
@ -64,19 +89,39 @@ class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
dtype=self._dtype,
device=input_ids.device,
)
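# Rough shape sketch (assumed shapes, for orientation only): inputs_embeds is
# [B, S, hidden_size]; with the mock default of num_capture_layers = 1, the injected
# hidden_states is also [B, S, hidden_size]; the drafter returns norm_hidden_state
# [B, S, hidden_size], which lm_head projects to logits over the drafter vocabulary.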
return super().forward(input_ids, **kwargs)
draft_output = super().forward(inputs_embeds, position_ids, **kwargs)
logits = self.lm_head(draft_output.norm_hidden_state)
return Eagle3DraftOutput(logits=logits, last_hidden_state=draft_output.last_hidden_state)
class MockEagleDrafterFactory(EagleDrafterFactory):
"""Test factory that uses MockEagle3ModelForCausalLM for standalone Eagle testing.
This factory overrides the drafter mapping to use the mock model class which
generates random hidden states, enabling testing without a target model.
This factory directly builds MockEagle3ModelForCausalLM with MockEagleConfig,
which loads embedding/lm_head from checkpoint for standalone testing.
"""
_drafter_model_mapping = {
"llama": MockEagle3ModelForCausalLM,
}
def _build_model(self, device):
from contextlib import nullcontext
from accelerate import init_empty_weights
model_config, unused_kwargs = self._get_model_config()
model_config = MockEagleConfig(model_config, model_config.model_type)
with (init_empty_weights if device == "meta" else nullcontext)():
model = MockEagle3ModelForCausalLM._from_config(model_config, **unused_kwargs)
if device == "meta":
if hasattr(model, "post_init"):
model.post_init()
else:
model.to(device)
self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
model.eval()
return model
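# Usage sketch (hypothetical call, mirroring how EagleDrafterFactory is driven in
# test_eagle_model_torch_export below):
#   factory = MockEagleDrafterFactory(model=str(eagle_path))
#   model = factory.build_model("cuda")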
@pytest.fixture
@ -125,7 +170,7 @@ def test_eagle_model_torch_export():
torch.export for potential TensorRT compilation.
Note: We skip loading weights since torch.export only traces the computation
graph (model architecture), not the actual weight values. Random init is fine.
graph (model architecture).
"""
print("\n" + "=" * 80)
print("Test: EagleModel torch.export")
@ -141,40 +186,35 @@ def test_eagle_model_torch_export():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16
# Use pretrained config (llama) - can provide it directly to instantiate Eagle3Model.
config_path = eagle_path / "config.json"
config = AutoConfig.from_pretrained(config_path)
# Create model with random weights (no need to load for export test)
model = Eagle3Model(config)
model.to(device)
model.eval()
# Create model via EagleDrafterFactory (creates Eagle3DrafterForCausalLM)
factory = EagleDrafterFactory(model=str(eagle_path), skip_loading_weights=True)
model = factory.build_model(device)
config = model.config
# Create inputs for export
batch_size = 1
seq_len = 8
hidden_dim = config.hidden_size
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long
)
inputs_embeds = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype)
position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
mock_hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype)
print("Export input shapes:")
print(f" input_ids: {input_ids.shape}")
print(f" inputs_embeds: {inputs_embeds.shape}")
print(f" position_ids: {position_ids.shape}")
print(f" hidden_states: {mock_hidden_states.shape}")
example_args = (
input_ids,
inputs_embeds,
position_ids,
mock_hidden_states,
)
# Attempt torch.export
try:
exported_program = torch.export.export(model, args=example_args)
exported_program = torch.export.export(
model, args=example_args, kwargs={"hidden_states": mock_hidden_states}
)
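# hidden_states goes through kwargs so torch.export records it as a named input, while
# the positional example_args line up with the drafter's leading forward parameters
# (inputs_embeds, position_ids), as assumed from the forward call in the mock above.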
print("✅ torch.export successful!")
print("Graph module code preview (first 20 lines):")
code_lines = exported_program.graph_module.code.split("\n")[:20]

View File

@ -21,6 +21,9 @@ from test_common.llm_data import with_mocked_hf_download
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig
@pytest.mark.skip(
reason="OOM on A30 GPUs on CI - speculative model loading does not support model_kwargs reduction"
)
@pytest.mark.parametrize("use_hf_speculative_model", [False])
@with_mocked_hf_download
def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):