diff --git a/tensorrt_llm/_torch/auto_deploy/models/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/__init__.py index 327d084bf0..0111d1aff9 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/__init__.py @@ -1,2 +1,2 @@ -from . import custom, hf, nemotron_flash, patches +from . import custom, eagle, hf, nemotron_flash, patches from .factory import * diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index e32f72f56f..4ad9e96cb8 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,7 +1,9 @@ +from .modeling_eagle import Eagle3DrafterForCausalLM from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast from .modeling_nemotron_h import NemotronHForCausalLM __all__ = ( + "Eagle3DrafterForCausalLM", "NemotronFlashForCausalLM", "NemotronFlashPreTrainedTokenizerFast", "NemotronHForCausalLM", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py new file mode 100644 index 0000000000..c16a6750d6 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py @@ -0,0 +1,434 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Eagle3 model implementation for AutoDeploy. + +Eagle3 is a speculative decoding draft model that predicts next tokens based on +hidden states from a target model (e.g., Llama-3.1-8B-Instruct). + +This file contains model definitions used for executing Eagle3 speculative decoding in AutoDeploy. 
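+
+A minimal usage sketch (names and shapes are illustrative; in practice the
+wiring below is performed by the AutoDeploy speculative decoding pipeline):
+
+    drafter = Eagle3DrafterForCausalLM(config)
+    # `captured` concatenates hidden states from the target model's capture
+    # layers to [batch, seq, 3 * hidden_size]; fuse them to hidden_size first.
+    fused = drafter.model.apply_eagle3_fc(captured)
+    out = drafter(input_ids, position_ids=position_ids, hidden_states=fused)
+    draft_logits = out.logits  # [batch, seq, draft_vocab_size]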
+""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.utils import ModelOutput + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, config, dim, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.base = getattr(config, "rope_theta", 10000.0) + self.config = config + + self.factor = 2 + + max_position_embeddings = self.config.max_position_embeddings + + if ( + not hasattr(config, "rope_type") + or config.rope_type is None + or config.rope_type == "default" + ): + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim) + ) + self.max_seq_len_cached = max_position_embeddings + + elif config.rope_type == "ntk": + assert self.config.orig_max_position_embeddings is not None + orig_max_position_embeddings = self.config.orig_max_position_embeddings + + self.base = self.base * ( + (self.factor * max_position_embeddings / orig_max_position_embeddings) + - (self.factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim) + ) + + self.max_seq_len_cached = orig_max_position_embeddings + else: + raise ValueError(f"Not support rope_type: {config.rope_type}") + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+    Returns:
+        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    if q is not None:
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+    else:
+        q_embed = None
+
+    if k is not None:
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+    else:
+        k_embed = None
+    return q_embed, k_embed
+
+
+class EagleRMSNorm(nn.Module):
+    """RMSNorm implementation that uses the torch_rmsnorm custom op."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        result = torch.ops.auto_deploy.torch_rmsnorm(
+            hidden_states, self.weight, self.variance_epsilon
+        )
+        return result
+
+
+class EagleMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class Eagle3Attention(nn.Module):
+    """Multi-headed attention from the 'Attention Is All You Need' paper."""
+
+    def __init__(self, config, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.is_causal = True
+
+        # Note: Eagle3Attention expects 2 * hidden_size input, which is the concatenation of the
+        # hidden states and the input embeddings.
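+        # Concretely, Eagle3DecoderLayer.forward builds the attention input as
+        #   cat([input_layernorm(embeds), hidden_norm(hidden_states)], dim=-1),
+        # so q/k/v project from 2 * hidden_size down to heads * head_dim, while
+        # o_proj maps the attention output back to hidden_size.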
+
+        self.q_proj = nn.Linear(
+            2 * config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.k_proj = nn.Linear(
+            2 * config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.v_proj = nn.Linear(
+            2 * config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.attention_bias,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        bsz, q_len, _ = hidden_states.size()
+        cos, sin = position_embeddings
+
+        # Projections
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Reshape to [Batch, Seq, Heads, Dim]
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim)
+
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, unsqueeze_dim=2
+        )
+
+        attn_output = torch.ops.auto_deploy.torch_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=self.is_causal,
+            layout="bsnd",
+        )
+
+        attn_output = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output
+
+
+class Eagle3DecoderLayer(nn.Module):
+    """Eagle decoder layer with modified attention and hidden state normalization."""
+
+    def __init__(self, config, layer_idx: int = 0):
+        super().__init__()
+        self.self_attn = Eagle3Attention(config, layer_idx=layer_idx)
+        self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = EagleMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        embeds: torch.Tensor,
+        position_embeds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.hidden_norm(hidden_states)
+
+        embeds = self.input_layernorm(embeds)
+
+        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeds,
+        )
+
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Eagle3Model(nn.Module):
+    """Core Eagle model architecture."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+        if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size:
+            # Vocab mappings for draft <-> target token conversion.
+            # Needed to convert draft outputs to target inputs for Eagle3.
+            # Since we reuse the target model's embedding in the drafter, we need
+            # to do this conversion after every draft iteration.
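+            # Sketch of the intended remap (performed in the speculative
+            # decoding loop, not in this module), assuming d2t holds
+            # per-draft-token offsets into the target vocabulary:
+            #   target_token_id = draft_token_id + d2t[draft_token_id]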
+            self.d2t = nn.Parameter(
+                torch.empty((config.draft_vocab_size,), dtype=torch.int32),
+                requires_grad=False,
+            )
+
+        # Input feature fusion: 3 * hidden_size -> hidden_size for Eagle3.
+        # TODO: Can make this configurable based on number of capture layers.
+        self.fc = nn.Linear(
+            config.hidden_size * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+            dtype=config.torch_dtype,
+        )
+
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+
+        self.rotary_emb = LlamaRotaryEmbedding(
+            config=config, dim=self.head_dim, device=torch.device("cuda")
+        )
+
+        if config.num_hidden_layers > 1:
+            self.midlayer = nn.ModuleList(
+                [Eagle3DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+            )
+        else:
+            self.midlayer = Eagle3DecoderLayer(config, layer_idx=0)
+
+        self.num_hidden_layers = config.num_hidden_layers
+
+    # Assumption: The hidden states are already fused if necessary.
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        position_ids: torch.LongTensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        input_embeds = self.embed_tokens(input_ids)
+
+        cos, sin = self.rotary_emb(hidden_states, position_ids)
+        position_embeds = (cos, sin)
+
+        if self.num_hidden_layers > 1:
+            for layer in self.midlayer:
+                hidden_states = layer(
+                    hidden_states=hidden_states,
+                    embeds=input_embeds,
+                    position_embeds=position_embeds,
+                )
+        else:
+            hidden_states = self.midlayer(
+                hidden_states=hidden_states,
+                embeds=input_embeds,
+                position_embeds=position_embeds,
+            )
+
+        return hidden_states
+
+    def apply_eagle3_fc(self, target_hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.fc(target_hidden_states)
+
+
+@dataclass
+class Eagle3DraftOutput(ModelOutput):
+    logits: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+
+
+class Eagle3DrafterForCausalLM(PreTrainedModel):
+    """HuggingFace-compatible wrapper for Eagle3Model.
+
+    This wrapper makes Eagle3Model compatible with AutoDeploy's model loading
+    and inference pipeline.
+    """
+
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = False
+    _no_split_modules = ["Eagle3DecoderLayer"]
+
+    # Checkpoint conversion mapping: Eagle checkpoints have keys like "fc.weight"
+    # but the wrapper model expects "model.fc.weight" (due to self.model = Eagle3Model).
+    # This mapping tells the factory to add the "model." prefix when loading weights.
+    # Used by AutoModelForCausalLMFactory._remap_param_names_load_hook().
+
+    _checkpoint_conversion_mapping = {
+        "^(?!lm_head|norm)": "model.",  # Prepend "model." to all keys EXCEPT lm_head and norm
+    }
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Eagle3Model(config)
+        self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
+
+        eagle_config = getattr(config, "eagle_config", {})
+        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Eagle3DraftOutput:
+        """
+        Kwargs:
+            hidden_states: Hidden states from the target model. Required.
+
+        Raises:
+            ValueError: If hidden_states is not provided in kwargs.
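+
+        Example (a minimal sketch; tensor names are illustrative):
+            # `fused` holds target-model hidden states already reduced to
+            # [batch, seq, hidden_size] via Eagle3Model.apply_eagle3_fc().
+            out = drafter(input_ids, hidden_states=fused)
+            draft_logits = out.logits  # [batch, seq, draft_vocab_size]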
+ """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Generate position_ids if not provided + if position_ids is None: + position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + position_ids = position_ids.expand(batch_size, -1) + + hidden_states = kwargs.get("hidden_states") + if hidden_states is None: + raise ValueError("hidden_states must be provided.") + + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + ) + + norm_hidden_states = self.norm(hidden_states) + logits = self.lm_head(norm_hidden_states) + + last_hidden_state = norm_hidden_states if self._return_hidden_post_norm else hidden_states + + return Eagle3DraftOutput( + logits=logits, + last_hidden_state=last_hidden_state, + ) diff --git a/tensorrt_llm/_torch/auto_deploy/models/eagle.py b/tensorrt_llm/_torch/auto_deploy/models/eagle.py new file mode 100644 index 0000000000..fd60b40b35 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/eagle.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory definitions for building models related to Eagle in AutoDeploy. + +This module provides EagleDrafterFactory, a specialized factory for building +Eagle speculative decoding draft models. It extends AutoModelForCausalLMFactory +to handle the mapping from base model types (e.g., "llama") to their corresponding +Eagle drafter implementations. +""" + +from contextlib import nullcontext +from typing import Dict, Type + +import torch.nn as nn +from accelerate import init_empty_weights +from torch._prims_common import DeviceLikeType +from transformers import PreTrainedModel + +from .custom.modeling_eagle import Eagle3DrafterForCausalLM +from .factory import ModelFactoryRegistry +from .hf import AutoModelForCausalLMFactory + + +@ModelFactoryRegistry.register("EagleDrafter") +class EagleDrafterFactory(AutoModelForCausalLMFactory): + """Factory for building Eagle drafter models. + + This factory handles the mapping from base model types (e.g., "llama") to + their corresponding Eagle drafter model implementations. It overrides + _build_model() to directly construct the appropriate drafter class based + on the checkpoint's model_type. + + The checkpoint config is expected to have the base model's model_type + (e.g., "llama") along with Eagle-specific fields like draft_vocab_size. 
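+
+    Example (a minimal sketch; the checkpoint path is a placeholder):
+        factory = EagleDrafterFactory(model="/path/to/eagle3-checkpoint")
+        model = factory.build_model("meta")
+        factory.load_or_random_init(model, "cuda")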
+ """ + + # Map config model_type -> Eagle drafter model class + _drafter_model_mapping: Dict[str, Type[PreTrainedModel]] = { + "llama": Eagle3DrafterForCausalLM, + } + + def _build_model(self, device: DeviceLikeType) -> nn.Module: + model_config, unused_kwargs = self._get_model_config() + + # Select the appropriate drafter class based on the base model type + match model_config.model_type: + case "llama": + drafter_cls = self._drafter_model_mapping["llama"] + case _: + raise ValueError( + f"Unsupported model_type '{model_config.model_type}' for Eagle drafter. " + f"Supported types: {list(self._drafter_model_mapping.keys())}" + ) + + # Build the model (same pattern as parent's _build_model) + with (init_empty_weights if device == "meta" else nullcontext)(): + model = drafter_cls._from_config(model_config, **unused_kwargs) + + if device == "meta": + # post-init must be called explicitly for HF models with init_empty_weights + if hasattr(model, "post_init"): + model.post_init() + else: + model.to(device) + + # Store checkpoint conversion mapping if present + self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None) + + model.eval() + + return model + + def build_and_load_model(self, device: DeviceLikeType) -> nn.Module: + raise NotImplementedError( + "EagleDrafterFactory does not support build_and_load_model(). " + "Use build_model() + load_or_random_init() instead." + ) diff --git a/tests/integration/defs/examples/test_ad_speculative_decoding.py b/tests/integration/defs/examples/test_ad_speculative_decoding.py index e77709aa7c..19c5beb66e 100644 --- a/tests/integration/defs/examples/test_ad_speculative_decoding.py +++ b/tests/integration/defs/examples/test_ad_speculative_decoding.py @@ -14,13 +14,17 @@ # limitations under the License. import os +import re +from pathlib import Path import pytest +import torch from build_and_run_ad import ExperimentConfig, main from defs.conftest import llm_models_root from tensorrt_llm import SamplingParams from tensorrt_llm._torch.auto_deploy.llm import LLM +from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig prompts = [ @@ -276,3 +280,171 @@ def test_autodeploy_eagle3_acceptance_rate(): print("\n" + "=" * 80) print("SUCCESS! All requests passed acceptance rate threshold") print("=" * 80) + + +def load_weights(model_path: Path, model: torch.nn.Module): + """Load weights from checkpoint while applying the same _checkpoint_conversion_mapping that the factory uses. + + Returns: tuple of (loaded_keys, missing_keys, unexpected_keys) + """ + # 1. Load checkpoint keys + bin_path = model_path / "pytorch_model.bin" + safetensors_path = model_path / "model.safetensors" + + if safetensors_path.exists(): + from safetensors import safe_open + + with safe_open(safetensors_path, framework="pt") as f: + checkpoint_keys_original = list(f.keys()) + elif bin_path.exists(): + state_dict = torch.load(bin_path, map_location="cpu", weights_only=True) + checkpoint_keys_original = list(state_dict.keys()) + del state_dict + else: + raise FileNotFoundError(f"No checkpoint found at {model_path}") + + # 2. 
Apply _checkpoint_conversion_mapping (same logic as hf.py _remap_param_names_load_hook) + # This is the key part - the factory does this exact same thing in lines 496-512 of hf.py + conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None) + checkpoint_keys_remapped = [] + + for key in checkpoint_keys_original: + new_key = key + if conversion_mapping: + for pattern, replacement in conversion_mapping.items(): + new_key = re.sub(pattern, replacement, new_key) + checkpoint_keys_remapped.append(new_key) + + # 3. Get model's expected keys + model_keys = set(model.state_dict().keys()) + checkpoint_keys = set(checkpoint_keys_remapped) + + # 4. Calculate differences + loaded_keys = checkpoint_keys & model_keys + missing_in_checkpoint = model_keys - checkpoint_keys + unexpected_in_checkpoint = checkpoint_keys - model_keys + + return loaded_keys, missing_in_checkpoint, unexpected_in_checkpoint + + +def test_eagle_model_with_weights(): + """Test EagleModel forward pass with loaded weights using the EagleDrafterFactory. + + This test uses EagleDrafterFactory to initialize the model, which directly + builds the Eagle drafter model based on the checkpoint's model_type: + + 1. Factory creates config via AutoConfig.from_pretrained + 2. Factory selects Eagle3DrafterForCausalLM based on model_type="llama" + 3. Factory creates model via _from_config + 4. Factory loads weights via load_or_random_init -> _load_checkpoint + + This ensures the test validates the exact initialization path used in production. + """ + print("\n" + "=" * 80) + print("Test: EagleModel forward pass with loaded weights (via EagleDrafterFactory)") + print("=" * 80) + + _, _, eagle_model_path = get_model_paths() + eagle_path = Path(eagle_model_path) + + if not eagle_path.exists(): + pytest.skip(f"Eagle model not found at {eagle_model_path}") + + # Check for weights + bin_path = eagle_path / "pytorch_model.bin" + safetensors_path = eagle_path / "model.safetensors" + if not bin_path.exists() and not safetensors_path.exists(): + pytest.skip(f"Weights not found at {eagle_model_path}") + + # 1. Setup Device + device = "cuda" if torch.cuda.is_available() else "cpu" + + # 2. Create factory + # EagleDrafterFactory directly builds the correct drafter model based on model_type + print("Creating EagleDrafterFactory...") + factory = EagleDrafterFactory( + model=eagle_model_path, + skip_loading_weights=False, # We want to test weight loading + ) + + # 3. Build model using factory + # Factory flow: + # build_model() -> prefetch_checkpoint() -> _build_model() + # _build_model() -> _get_model_config() (gets base LlamaConfig) + # _build_model() -> selects Eagle3DrafterForCausalLM for model_type="llama" + # _build_model() -> Eagle3DrafterForCausalLM._from_config(config) + print("Building model via factory.build_model('meta')...") + model = factory.build_model("meta") + print(f"Model type: {type(model).__name__}") + print(f"Model config type: {type(model.config).__name__}") + + # 4. 
Load weights from checkpoint and compare to model's expected keys + print("\n--- Weight Loading Analysis ---") + loaded_keys, missing_keys, unexpected_keys = load_weights(eagle_path, model) + + print(f"Total model parameters: {len(loaded_keys) + len(missing_keys)}") + print(f"Total checkpoint keys: {len(loaded_keys) + len(unexpected_keys)}") + print(f"✅ Weights to be loaded: {len(loaded_keys)}") + print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}") + print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}") + + if missing_keys: + print("\nMissing keys (model expects but checkpoint doesn't have):") + for key in sorted(missing_keys): + if "embed_tokens" in key: + print(f" - {key} (expected: shared from target model)") + elif "rotary_emb" in key: + print(f" - {key} (expected: computed at runtime)") + else: + print(f" - {key}") + + if unexpected_keys: + print("\nUnexpected keys (in checkpoint but model doesn't expect):") + for key in sorted(unexpected_keys): + if "t2d" in key: + print(f" - {key} (expected: not used in Eagle3)") + else: + print(f" - {key}") + + if loaded_keys: + print(f"\nLoaded keys ({len(loaded_keys)} total):") + for key in sorted(loaded_keys)[:10]: + print(f" - {key}") + if len(loaded_keys) > 10: + print(f" ... and {len(loaded_keys) - 10} more") + + print("--- End Weight Analysis ---\n") + + # Verify expected missing and unexpected keys + # These are the keys we expect based on Eagle3 architecture: + # - embed_tokens: shared from target model (not in Eagle checkpoint) + # - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead) + expected_missing_keys = {"model.embed_tokens.weight"} + expected_unexpected_keys = {"model.t2d"} + + assert missing_keys == expected_missing_keys, ( + f"Unexpected missing keys.\n" + f"Expected: {expected_missing_keys}\n" + f"Got: {missing_keys}\n" + f"Extra missing: {missing_keys - expected_missing_keys}\n" + f"Not missing (but expected): {expected_missing_keys - missing_keys}" + ) + + assert unexpected_keys == expected_unexpected_keys, ( + f"Unexpected keys in checkpoint.\n" + f"Expected: {expected_unexpected_keys}\n" + f"Got: {unexpected_keys}\n" + f"Extra unexpected: {unexpected_keys - expected_unexpected_keys}\n" + f"Not unexpected (but expected): {expected_unexpected_keys - unexpected_keys}" + ) + + print("✅ Weight loading analysis matches expected missing/unexpected keys!") + + # 5. Load weights using factory (mimics actual pipeline) + # If tensor shapes do not match with how they are used in the forward() function, we will + # get an error. 
+ print("Loading weights via factory.load_or_random_init()...") + factory.load_or_random_init(model, device) + print("Weights loaded successfully via factory interface!") + + model.eval() diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index bb498ddeeb..bb12a3302e 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -441,3 +441,4 @@ l0_h100: - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target] - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3] - examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate + - examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index 39c071cb20..d585e4e088 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -484,6 +484,11 @@ _SMALL_MODEL_CONFIGS = { "num_hidden_layers": 8, }, }, + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": { + "model_kwargs": { + "hidden_size": 64, + } + }, } diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py new file mode 100644 index 0000000000..bbf4ede408 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Eagle3 model with AutoDeploy.""" + +from pathlib import Path + +import pytest +import torch +from _model_test_utils import get_small_model_config +from build_and_run_ad import ExperimentConfig, main +from transformers import AutoConfig + +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( + Eagle3DrafterForCausalLM, + Eagle3Model, +) +from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry +from tests.test_common.llm_data import hf_id_to_local_model_dir + +EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + + +############################################################################### +# Mock classes for standalone Eagle testing +# +# These classes enable unit testing the Eagle checkpoint without a target model. +# In production speculative decoding, real hidden states come from the target model. +# For testing, MockEagle3ModelForCausalLM generates random hidden states. 
+############################################################################### + + +class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM): + """Test wrapper that provides random hidden states for standalone Eagle testing. + + In production speculative decoding, real hidden states come from the target model. + This mock class generates random hidden states for testing the Eagle model in isolation. + """ + + def __init__(self, config): + super().__init__(config) + self._hidden_size = config.hidden_size + self._dtype = config.dtype + + def forward(self, input_ids, **kwargs): + # Inject mock hidden states if not provided + if "hidden_states" not in kwargs: + batch_size, seq_len = input_ids.shape + kwargs["hidden_states"] = torch.randn( + (batch_size, seq_len, self._hidden_size), + dtype=self._dtype, + device=input_ids.device, + ) + return super().forward(input_ids, **kwargs) + + +class MockEagleDrafterFactory(EagleDrafterFactory): + """Test factory that uses MockEagle3ModelForCausalLM for standalone Eagle testing. + + This factory overrides the drafter mapping to use the mock model class which + generates random hidden states, enabling testing without a target model. + """ + + _drafter_model_mapping = { + "llama": MockEagle3ModelForCausalLM, + } + + +@pytest.fixture +def register_mock_eagle_factory(): + """Register MockEagleDrafterFactory for the test and clean up afterwards. + + This fixture temporarily registers the mock factory with ModelFactoryRegistry, + allowing tests to use model_factory="MockEagleDrafter", and removes the + registration after the test completes. + """ + ModelFactoryRegistry._registry["MockEagleDrafter"] = MockEagleDrafterFactory + yield + ModelFactoryRegistry._registry.pop("MockEagleDrafter", None) + + +def test_build_ad_eagle(register_mock_eagle_factory): + """Test building Eagle model with AutoDeploy using MockEagleDrafterFactory. + + This test uses the MockEagleDrafterFactory which builds MockEagle3ModelForCausalLM, + a mock model that generates random hidden states for standalone Eagle testing. + """ + llm_extra_args = { + "model_factory": "MockEagleDrafter", + "transforms": { + "insert_cached_attention": {"backend": "flashinfer"}, + "compile_model": {"backend": "torch-compile"}, + }, + } + experiment_config = get_small_model_config(EAGLE_MODEL_HUB_ID, **llm_extra_args) + experiment_config["args"]["runtime"] = "demollm" + experiment_config["args"]["world_size"] = 0 + experiment_config["args"]["tokenizer"] = hf_id_to_local_model_dir( + "meta-llama/Meta-Llama-3.1-8B-Instruct" + ) + + print(f"Experiment Config: {experiment_config}") + experiment_config = ExperimentConfig(**experiment_config) + + main(experiment_config) + + +def test_eagle_model_torch_export(): + """Test that Eagle3Model can be exported with torch.export. + + This validates that the model architecture is compatible with + torch.export for potential TensorRT compilation. + + Note: We skip loading weights since torch.export only traces the computation + graph (model architecture), not the actual weight values. Random init is fine. 
+ """ + print("\n" + "=" * 80) + print("Test: EagleModel torch.export") + print("=" * 80) + + eagle_model_path = hf_id_to_local_model_dir(EAGLE_MODEL_HUB_ID) + if eagle_model_path is None: + pytest.skip("Eagle model not found (LLM_MODELS_ROOT not set or model missing)") + + eagle_path = Path(eagle_model_path) + + # Setup + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 + + # Use pretrained config (llama) - can provide it directly to instantiate Eagle3Model. + config_path = eagle_path / "config.json" + config = AutoConfig.from_pretrained(config_path) + + # Create model with random weights (no need to load for export test) + model = Eagle3Model(config) + model.to(device) + model.eval() + + # Create inputs for export + batch_size = 1 + seq_len = 8 + hidden_dim = config.hidden_size + + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long + ) + position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + mock_hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype) + + print("Export input shapes:") + print(f" input_ids: {input_ids.shape}") + print(f" position_ids: {position_ids.shape}") + print(f" hidden_states: {mock_hidden_states.shape}") + + example_args = ( + input_ids, + position_ids, + mock_hidden_states, + ) + + # Attempt torch.export + try: + exported_program = torch.export.export(model, args=example_args) + print("✅ torch.export successful!") + print("Graph module code preview (first 20 lines):") + code_lines = exported_program.graph_module.code.split("\n")[:20] + print("\n".join(code_lines)) + except Exception as e: + pytest.fail(f"torch.export failed: {e}") diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py index 81481e8f51..5b78647083 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py @@ -21,7 +21,7 @@ from test_common.llm_data import with_mocked_hf_download from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig -@pytest.mark.parametrize("use_hf_speculative_model", [False, True]) +@pytest.mark.parametrize("use_hf_speculative_model", [False]) @with_mocked_hf_download def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool): """Test speculative decoding with AutoDeploy using the build_and_run_ad main()."""