[None][chore] AutoDeploy: Eagle One-Model [1/n]: PyTorch impl for Eagle3 Llama checkpoint (#10674)

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
This commit is contained in:
gramnarayan 2026-01-28 12:10:49 -08:00 committed by GitHub
parent 0ffa77af51
commit 744a955cbb
9 changed files with 889 additions and 2 deletions

View File

@@ -1,2 +1,2 @@
from . import custom, hf, nemotron_flash, patches
from . import custom, eagle, hf, nemotron_flash, patches
from .factory import *

View File

@@ -1,7 +1,9 @@
from .modeling_eagle import Eagle3DrafterForCausalLM
from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
from .modeling_nemotron_h import NemotronHForCausalLM
__all__ = (
"Eagle3DrafterForCausalLM",
"NemotronFlashForCausalLM",
"NemotronFlashPreTrainedTokenizerFast",
"NemotronHForCausalLM",

View File

@@ -0,0 +1,434 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Eagle3 model implementation for AutoDeploy.
Eagle3 is a speculative decoding draft model that predicts next tokens based on
hidden states from a target model (e.g., Llama-3.1-8B-Instruct).
This file contains model definitions used for executing Eagle3 speculative decoding in AutoDeploy.
"""
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, config, dim, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.base = getattr(config, "rope_theta", 10000.0)
self.config = config
self.factor = 2
max_position_embeddings = self.config.max_position_embeddings
if (
not hasattr(config, "rope_type")
or config.rope_type is None
or config.rope_type == "default"
):
inv_freq = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
)
self.max_seq_len_cached = max_position_embeddings
elif config.rope_type == "ntk":
assert self.config.orig_max_position_embeddings is not None
orig_max_position_embeddings = self.config.orig_max_position_embeddings
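# NTK-aware scaling: rescale the RoPE base so the embedding stretches to cover
# factor * max_position_embeddings positions, via the standard NTK formula
# base' = base * (factor * L / L_orig - (factor - 1)) ** (dim / (dim - 2)).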
self.base = self.base * (
(self.factor * max_position_embeddings / orig_max_position_embeddings)
- (self.factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
)
self.max_seq_len_cached = orig_max_position_embeddings
else:
raise ValueError(f"Not support rope_type: {config.rope_type}")
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = (
device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
if q is not None:
q_embed = (q * cos) + (rotate_half(q) * sin)
else:
q_embed = None
if k is not None:
k_embed = (k * cos) + (rotate_half(k) * sin)
else:
k_embed = None
return q_embed, k_embed
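# Shape sketch (illustrative): Eagle3Attention below calls this with q and k in
# [batch, seq, heads, head_dim] and cos/sin in [batch, seq, head_dim];
# unsqueeze_dim=2 inserts the heads axis so cos/sin broadcast as
# [batch, seq, 1, head_dim] against q and k.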
class EagleRMSNorm(nn.Module):
"""RMSNorm implementation that uses the torch_rmsnorm custom op."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
result = torch.ops.auto_deploy.torch_rmsnorm(
hidden_states, self.weight, self.variance_epsilon
)
return result
class EagleMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class Eagle3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.is_causal = True
# Note: Eagle3Attention expects 2 * hidden_size input, which is the concatenation of the hidden states
# and the input embeddings.
self.q_proj = nn.Linear(
2 * config.hidden_size,
config.num_attention_heads * self.head_dim,
bias=config.attention_bias,
)
self.k_proj = nn.Linear(
2 * config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
2 * config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim,
config.hidden_size,
bias=config.attention_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
bsz, q_len, _ = hidden_states.size()
cos, sin = position_embeddings
# Projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Reshape to [Batch, Seq, Heads, Dim]
query_states = query_states.view(bsz, q_len, -1, self.head_dim)
key_states = key_states.view(bsz, q_len, -1, self.head_dim)
value_states = value_states.view(bsz, q_len, -1, self.head_dim)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=2
)
attn_output = torch.ops.auto_deploy.torch_attention(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=self.is_causal,
layout="bsnd",
)
attn_output = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
class Eagle3DecoderLayer(nn.Module):
"""Eagle decoder layer with modified attention and hidden state normalization."""
def __init__(self, config, layer_idx: int = 0):
super().__init__()
self.self_attn = Eagle3Attention(config, layer_idx=layer_idx)
self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = EagleMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
embeds: torch.Tensor,
position_embeds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.hidden_norm(hidden_states)
embeds = self.input_layernorm(embeds)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeds,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Eagle3Model(nn.Module):
"""Core Eagle model architecture."""
def __init__(self, config):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size:
# Vocab mappings for draft <-> target token conversion
# Needed to convert draft outputs to target inputs for Eagle3.
# Since we reuse the target model's embedding in the drafter, we need
# to do this conversion after every draft iteration.
self.d2t = nn.Parameter(
torch.empty((config.draft_vocab_size,), dtype=torch.int32),
requires_grad=False,
)
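# Illustrative use (handled by the runtime, not this module): a sampled draft
# token is mapped back into the target vocabulary before the next iteration,
# e.g. target_id = draft_id + d2t[draft_id], assuming d2t stores per-token
# offsets from draft ids to target ids.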
# Input feature fusion: 3 * hidden_size -> hidden_size for Eagle3.
# TODO: Can make this configurable based on number of capture layers.
self.fc = nn.Linear(
config.hidden_size * 3,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype,
)
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.rotary_emb = LlamaRotaryEmbedding(
config=config, dim=self.head_dim, device=torch.device("cuda")
)
if config.num_hidden_layers > 1:
self.midlayer = nn.ModuleList(
[Eagle3DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
)
else:
self.midlayer = Eagle3DecoderLayer(config, layer_idx=0)
self.num_hidden_layers = config.num_hidden_layers
# Assumption: The hidden states are already fused if necessary
def forward(
self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_embeds = self.embed_tokens(input_ids)
cos, sin = self.rotary_emb(hidden_states, position_ids)
position_embeds = (cos, sin)
if self.num_hidden_layers > 1:
for layer in self.midlayer:
hidden_states = layer(
hidden_states=hidden_states,
embeds=input_embeds,
position_embeds=position_embeds,
)
else:
hidden_states = self.midlayer(
hidden_states=hidden_states,
embeds=input_embeds,
position_embeds=position_embeds,
)
return hidden_states
def apply_eagle3_fc(self, target_hidden_states: torch.Tensor) -> torch.Tensor:
return self.fc(target_hidden_states)
@dataclass
class Eagle3DraftOutput(ModelOutput):
logits: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
class Eagle3DrafterForCausalLM(PreTrainedModel):
"""HuggingFace-compatible wrapper for EagleModel.
This wrapper makes EagleModel compatible with AutoDeploy's model loading
and inference pipeline.
"""
base_model_prefix = "model"
supports_gradient_checkpointing = False
_no_split_modules = ["Eagle3DecoderLayer"]
# Checkpoint conversion mapping: Eagle checkpoints have keys like "fc.weight"
# but the wrapper model expects "model.fc.weight" (due to self.model = Eagle3Model).
# This mapping tells the factory to add "model." prefix when loading weights.
# Used by AutoModelForCausalLMFactory._remap_param_names_load_hook()
_checkpoint_conversion_mapping = {
"^(?!lm_head|norm)": "model.", # Prepend "model." to all keys EXCEPT lm_head and norm
}
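# Illustration (not executed here): re.sub(r"^(?!lm_head|norm)", "model.", key)
# turns "fc.weight" into "model.fc.weight", while "lm_head.weight" and
# "norm.weight" are left untouched because the negative lookahead blocks the match.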
def __init__(self, config):
super().__init__(config)
self.model = Eagle3Model(config)
self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
eagle_config = getattr(config, "eagle_config", {})
self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)
def forward(
self,
input_ids: torch.LongTensor,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Eagle3DraftOutput:
"""
Kwargs:
hidden_states: Hidden states from the target model. Required.
Raises:
ValueError: If hidden_states is not provided in kwargs.
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Generate position_ids if not provided
if position_ids is None:
position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
position_ids = position_ids.expand(batch_size, -1)
hidden_states = kwargs.get("hidden_states")
if hidden_states is None:
raise ValueError("hidden_states must be provided.")
hidden_states = self.model(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
)
norm_hidden_states = self.norm(hidden_states)
logits = self.lm_head(norm_hidden_states)
last_hidden_state = norm_hidden_states if self._return_hidden_post_norm else hidden_states
return Eagle3DraftOutput(
logits=logits,
last_hidden_state=last_hidden_state,
)
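
For orientation, a minimal standalone forward-pass sketch of the drafter above (illustrative only, not part of this commit). It assumes a CUDA device, since Eagle3Model pins its rotary embedding to CUDA, a Llama-style config extended with the Eagle-specific draft_vocab_size field, and that importing the module registers the auto_deploy custom ops (torch_rmsnorm, torch_attention):

import torch
from transformers import LlamaConfig

from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import Eagle3DrafterForCausalLM

# Toy config; values are illustrative, not from a real checkpoint.
config = LlamaConfig(
    vocab_size=128256,
    hidden_size=64,
    intermediate_size=256,
    num_attention_heads=4,
    num_key_value_heads=1,
    head_dim=16,
    num_hidden_layers=1,
    max_position_embeddings=2048,
)
config.draft_vocab_size = 32000  # Eagle3-specific field read by Eagle3Model

model = Eagle3DrafterForCausalLM(config).to("cuda").eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8), device="cuda")
# In production, hidden_states come from the target model's captured layers,
# already fused from 3 * hidden_size down to hidden_size via model.apply_eagle3_fc().
hidden_states = torch.randn(1, 8, config.hidden_size, device="cuda")

out = model(input_ids, hidden_states=hidden_states)
assert out.logits.shape == (1, 8, config.draft_vocab_size)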

View File

@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory definitions for building models related to Eagle in AutoDeploy.
This module provides EagleDrafterFactory, a specialized factory for building
Eagle speculative decoding draft models. It extends AutoModelForCausalLMFactory
to handle the mapping from base model types (e.g., "llama") to their corresponding
Eagle drafter implementations.
"""
from contextlib import nullcontext
from typing import Dict, Type
import torch.nn as nn
from accelerate import init_empty_weights
from torch._prims_common import DeviceLikeType
from transformers import PreTrainedModel
from .custom.modeling_eagle import Eagle3DrafterForCausalLM
from .factory import ModelFactoryRegistry
from .hf import AutoModelForCausalLMFactory
@ModelFactoryRegistry.register("EagleDrafter")
class EagleDrafterFactory(AutoModelForCausalLMFactory):
"""Factory for building Eagle drafter models.
This factory handles the mapping from base model types (e.g., "llama") to
their corresponding Eagle drafter model implementations. It overrides
_build_model() to directly construct the appropriate drafter class based
on the checkpoint's model_type.
The checkpoint config is expected to have the base model's model_type
(e.g., "llama") along with Eagle-specific fields like draft_vocab_size.
"""
# Map config model_type -> Eagle drafter model class
_drafter_model_mapping: Dict[str, Type[PreTrainedModel]] = {
"llama": Eagle3DrafterForCausalLM,
}
def _build_model(self, device: DeviceLikeType) -> nn.Module:
model_config, unused_kwargs = self._get_model_config()
# Select the appropriate drafter class based on the base model type
match model_config.model_type:
case "llama":
drafter_cls = self._drafter_model_mapping["llama"]
case _:
raise ValueError(
f"Unsupported model_type '{model_config.model_type}' for Eagle drafter. "
f"Supported types: {list(self._drafter_model_mapping.keys())}"
)
# Build the model (same pattern as parent's _build_model)
with (init_empty_weights if device == "meta" else nullcontext)():
model = drafter_cls._from_config(model_config, **unused_kwargs)
if device == "meta":
# post-init must be called explicitly for HF models with init_empty_weights
if hasattr(model, "post_init"):
model.post_init()
else:
model.to(device)
# Store checkpoint conversion mapping if present
self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
model.eval()
return model
def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
raise NotImplementedError(
"EagleDrafterFactory does not support build_and_load_model(). "
"Use build_model() + load_or_random_init() instead."
)
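
For reference, a minimal sketch of the intended two-step flow (mirroring the tests added in this commit; the checkpoint path is a placeholder):

from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory

factory = EagleDrafterFactory(
    model="/path/to/EAGLE3-LLaMA3.1-Instruct-8B",  # placeholder checkpoint path
    skip_loading_weights=False,
)
model = factory.build_model("meta")         # construct on the meta device
factory.load_or_random_init(model, "cuda")  # materialize and load real weights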

View File

@@ -14,13 +14,17 @@
# limitations under the License.
import os
import re
from pathlib import Path
import pytest
import torch
from build_and_run_ad import ExperimentConfig, main
from defs.conftest import llm_models_root
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig
prompts = [
@@ -276,3 +280,171 @@ def test_autodeploy_eagle3_acceptance_rate():
print("\n" + "=" * 80)
print("SUCCESS! All requests passed acceptance rate threshold")
print("=" * 80)
def load_weights(model_path: Path, model: torch.nn.Module):
"""Load weights from checkpoint while applying the same _checkpoint_conversion_mapping that the factory uses.
Returns: tuple of (loaded_keys, missing_keys, unexpected_keys)
"""
# 1. Load checkpoint keys
bin_path = model_path / "pytorch_model.bin"
safetensors_path = model_path / "model.safetensors"
if safetensors_path.exists():
from safetensors import safe_open
with safe_open(safetensors_path, framework="pt") as f:
checkpoint_keys_original = list(f.keys())
elif bin_path.exists():
state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
checkpoint_keys_original = list(state_dict.keys())
del state_dict
else:
raise FileNotFoundError(f"No checkpoint found at {model_path}")
# 2. Apply _checkpoint_conversion_mapping (same logic as hf.py _remap_param_names_load_hook)
# This is the key step: the factory applies exactly this remapping when loading (hf.py lines 496-512)
conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
checkpoint_keys_remapped = []
for key in checkpoint_keys_original:
new_key = key
if conversion_mapping:
for pattern, replacement in conversion_mapping.items():
new_key = re.sub(pattern, replacement, new_key)
checkpoint_keys_remapped.append(new_key)
# 3. Get model's expected keys
model_keys = set(model.state_dict().keys())
checkpoint_keys = set(checkpoint_keys_remapped)
# 4. Calculate differences
loaded_keys = checkpoint_keys & model_keys
missing_in_checkpoint = model_keys - checkpoint_keys
unexpected_in_checkpoint = checkpoint_keys - model_keys
return loaded_keys, missing_in_checkpoint, unexpected_in_checkpoint
def test_eagle_model_with_weights():
"""Test EagleModel forward pass with loaded weights using the EagleDrafterFactory.
This test uses EagleDrafterFactory to initialize the model, which directly
builds the Eagle drafter model based on the checkpoint's model_type:
1. Factory creates config via AutoConfig.from_pretrained
2. Factory selects Eagle3DrafterForCausalLM based on model_type="llama"
3. Factory creates model via _from_config
4. Factory loads weights via load_or_random_init -> _load_checkpoint
This ensures the test validates the exact initialization path used in production.
"""
print("\n" + "=" * 80)
print("Test: EagleModel forward pass with loaded weights (via EagleDrafterFactory)")
print("=" * 80)
_, _, eagle_model_path = get_model_paths()
eagle_path = Path(eagle_model_path)
if not eagle_path.exists():
pytest.skip(f"Eagle model not found at {eagle_model_path}")
# Check for weights
bin_path = eagle_path / "pytorch_model.bin"
safetensors_path = eagle_path / "model.safetensors"
if not bin_path.exists() and not safetensors_path.exists():
pytest.skip(f"Weights not found at {eagle_model_path}")
# 1. Setup Device
device = "cuda" if torch.cuda.is_available() else "cpu"
# 2. Create factory
# EagleDrafterFactory directly builds the correct drafter model based on model_type
print("Creating EagleDrafterFactory...")
factory = EagleDrafterFactory(
model=eagle_model_path,
skip_loading_weights=False, # We want to test weight loading
)
# 3. Build model using factory
# Factory flow:
# build_model() -> prefetch_checkpoint() -> _build_model()
# _build_model() -> _get_model_config() (gets base LlamaConfig)
# _build_model() -> selects Eagle3DrafterForCausalLM for model_type="llama"
# _build_model() -> Eagle3DrafterForCausalLM._from_config(config)
print("Building model via factory.build_model('meta')...")
model = factory.build_model("meta")
print(f"Model type: {type(model).__name__}")
print(f"Model config type: {type(model.config).__name__}")
# 4. Load weights from checkpoint and compare to model's expected keys
print("\n--- Weight Loading Analysis ---")
loaded_keys, missing_keys, unexpected_keys = load_weights(eagle_path, model)
print(f"Total model parameters: {len(loaded_keys) + len(missing_keys)}")
print(f"Total checkpoint keys: {len(loaded_keys) + len(unexpected_keys)}")
print(f"✅ Weights to be loaded: {len(loaded_keys)}")
print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}")
print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}")
if missing_keys:
print("\nMissing keys (model expects but checkpoint doesn't have):")
for key in sorted(missing_keys):
if "embed_tokens" in key:
print(f" - {key} (expected: shared from target model)")
elif "rotary_emb" in key:
print(f" - {key} (expected: computed at runtime)")
else:
print(f" - {key}")
if unexpected_keys:
print("\nUnexpected keys (in checkpoint but model doesn't expect):")
for key in sorted(unexpected_keys):
if "t2d" in key:
print(f" - {key} (expected: not used in Eagle3)")
else:
print(f" - {key}")
if loaded_keys:
print(f"\nLoaded keys ({len(loaded_keys)} total):")
for key in sorted(loaded_keys)[:10]:
print(f" - {key}")
if len(loaded_keys) > 10:
print(f" ... and {len(loaded_keys) - 10} more")
print("--- End Weight Analysis ---\n")
# Verify expected missing and unexpected keys
# These are the keys we expect based on Eagle3 architecture:
# - embed_tokens: shared from target model (not in Eagle checkpoint)
# - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead)
expected_missing_keys = {"model.embed_tokens.weight"}
expected_unexpected_keys = {"model.t2d"}
assert missing_keys == expected_missing_keys, (
f"Unexpected missing keys.\n"
f"Expected: {expected_missing_keys}\n"
f"Got: {missing_keys}\n"
f"Extra missing: {missing_keys - expected_missing_keys}\n"
f"Not missing (but expected): {expected_missing_keys - missing_keys}"
)
assert unexpected_keys == expected_unexpected_keys, (
f"Unexpected keys in checkpoint.\n"
f"Expected: {expected_unexpected_keys}\n"
f"Got: {unexpected_keys}\n"
f"Extra unexpected: {unexpected_keys - expected_unexpected_keys}\n"
f"Not unexpected (but expected): {expected_unexpected_keys - unexpected_keys}"
)
print("✅ Weight loading analysis matches expected missing/unexpected keys!")
# 5. Load weights using factory (mimics actual pipeline)
# If checkpoint tensor shapes do not match how they are used in forward(),
# loading will raise an error.
print("Loading weights via factory.load_or_random_init()...")
factory.load_or_random_init(model, device)
print("Weights loaded successfully via factory interface!")
model.eval()

View File

@@ -441,3 +441,4 @@ l0_h100:
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
- examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights

View File

@@ -484,6 +484,11 @@ _SMALL_MODEL_CONFIGS = {
"num_hidden_layers": 8,
},
},
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": {
"model_kwargs": {
"hidden_size": 64,
}
},
}

View File

@@ -0,0 +1,183 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for Eagle3 model with AutoDeploy."""
from pathlib import Path
import pytest
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
from transformers import AutoConfig
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
Eagle3DrafterForCausalLM,
Eagle3Model,
)
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry
from tests.test_common.llm_data import hf_id_to_local_model_dir
EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
###############################################################################
# Mock classes for standalone Eagle testing
#
# These classes enable unit testing the Eagle checkpoint without a target model.
# In production speculative decoding, real hidden states come from the target model.
# For testing, MockEagle3ModelForCausalLM generates random hidden states.
###############################################################################
class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
"""Test wrapper that provides random hidden states for standalone Eagle testing.
In production speculative decoding, real hidden states come from the target model.
This mock class generates random hidden states for testing the Eagle model in isolation.
"""
def __init__(self, config):
super().__init__(config)
self._hidden_size = config.hidden_size
self._dtype = config.dtype
def forward(self, input_ids, **kwargs):
# Inject mock hidden states if not provided
if "hidden_states" not in kwargs:
batch_size, seq_len = input_ids.shape
kwargs["hidden_states"] = torch.randn(
(batch_size, seq_len, self._hidden_size),
dtype=self._dtype,
device=input_ids.device,
)
return super().forward(input_ids, **kwargs)
class MockEagleDrafterFactory(EagleDrafterFactory):
"""Test factory that uses MockEagle3ModelForCausalLM for standalone Eagle testing.
This factory overrides the drafter mapping to use the mock model class which
generates random hidden states, enabling testing without a target model.
"""
_drafter_model_mapping = {
"llama": MockEagle3ModelForCausalLM,
}
@pytest.fixture
def register_mock_eagle_factory():
"""Register MockEagleDrafterFactory for the test and clean up afterwards.
This fixture temporarily registers the mock factory with ModelFactoryRegistry,
allowing tests to use model_factory="MockEagleDrafter", and removes the
registration after the test completes.
"""
ModelFactoryRegistry._registry["MockEagleDrafter"] = MockEagleDrafterFactory
yield
ModelFactoryRegistry._registry.pop("MockEagleDrafter", None)
def test_build_ad_eagle(register_mock_eagle_factory):
"""Test building Eagle model with AutoDeploy using MockEagleDrafterFactory.
This test uses the MockEagleDrafterFactory which builds MockEagle3ModelForCausalLM,
a mock model that generates random hidden states for standalone Eagle testing.
"""
llm_extra_args = {
"model_factory": "MockEagleDrafter",
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-compile"},
},
}
experiment_config = get_small_model_config(EAGLE_MODEL_HUB_ID, **llm_extra_args)
experiment_config["args"]["runtime"] = "demollm"
experiment_config["args"]["world_size"] = 0
experiment_config["args"]["tokenizer"] = hf_id_to_local_model_dir(
"meta-llama/Meta-Llama-3.1-8B-Instruct"
)
print(f"Experiment Config: {experiment_config}")
experiment_config = ExperimentConfig(**experiment_config)
main(experiment_config)
def test_eagle_model_torch_export():
"""Test that Eagle3Model can be exported with torch.export.
This validates that the model architecture is compatible with
torch.export for potential TensorRT compilation.
Note: We skip loading weights since torch.export only traces the computation
graph (model architecture), not the actual weight values. Random init is fine.
"""
print("\n" + "=" * 80)
print("Test: EagleModel torch.export")
print("=" * 80)
eagle_model_path = hf_id_to_local_model_dir(EAGLE_MODEL_HUB_ID)
if eagle_model_path is None:
pytest.skip("Eagle model not found (LLM_MODELS_ROOT not set or model missing)")
eagle_path = Path(eagle_model_path)
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16
# Use the pretrained (llama) config; it can be passed directly to instantiate Eagle3Model.
config_path = eagle_path / "config.json"
config = AutoConfig.from_pretrained(config_path)
# Create model with random weights (no need to load for export test)
model = Eagle3Model(config)
model.to(device)
model.eval()
# Create inputs for export
batch_size = 1
seq_len = 8
hidden_dim = config.hidden_size
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long
)
position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
mock_hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype)
print("Export input shapes:")
print(f" input_ids: {input_ids.shape}")
print(f" position_ids: {position_ids.shape}")
print(f" hidden_states: {mock_hidden_states.shape}")
example_args = (
input_ids,
position_ids,
mock_hidden_states,
)
# Attempt torch.export
try:
exported_program = torch.export.export(model, args=example_args)
print("✅ torch.export successful!")
print("Graph module code preview (first 20 lines):")
code_lines = exported_program.graph_module.code.split("\n")[:20]
print("\n".join(code_lines))
except Exception as e:
pytest.fail(f"torch.export failed: {e}")

View File

@@ -21,7 +21,7 @@ from test_common.llm_data import with_mocked_hf_download
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig
@pytest.mark.parametrize("use_hf_speculative_model", [False, True])
@pytest.mark.parametrize("use_hf_speculative_model", [False])
@with_mocked_hf_download
def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):
"""Test speculative decoding with AutoDeploy using the build_and_run_ad main()."""