Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
[None][chore] AutoDeploy: Eagle One-Model [1/n]: PyTorch impl for Eagle3 Llama checkpoint (#10674)
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Parent: 0ffa77af51
Commit: 744a955cbb
@@ -1,2 +1,2 @@
-from . import custom, hf, nemotron_flash, patches
+from . import custom, eagle, hf, nemotron_flash, patches
 from .factory import *
@@ -1,7 +1,9 @@
from .modeling_eagle import Eagle3DrafterForCausalLM
from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
from .modeling_nemotron_h import NemotronHForCausalLM

__all__ = (
    "Eagle3DrafterForCausalLM",
    "NemotronFlashForCausalLM",
    "NemotronFlashPreTrainedTokenizerFast",
    "NemotronHForCausalLM",
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py (new file, 434 lines)
@@ -0,0 +1,434 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Eagle3 model implementation for AutoDeploy.

Eagle3 is a speculative decoding draft model that predicts next tokens based on
hidden states from a target model (e.g., Llama-3.1-8B-Instruct).

This file contains model definitions used for executing Eagle3 speculative decoding in AutoDeploy.
"""

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, config, dim, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = getattr(config, "rope_theta", 10000.0)
        self.config = config

        self.factor = 2

        max_position_embeddings = self.config.max_position_embeddings

        if (
            not hasattr(config, "rope_type")
            or config.rope_type is None
            or config.rope_type == "default"
        ):
            inv_freq = 1.0 / (
                self.base
                ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
            )
            self.max_seq_len_cached = max_position_embeddings

        elif config.rope_type == "ntk":
            assert self.config.orig_max_position_embeddings is not None
            orig_max_position_embeddings = self.config.orig_max_position_embeddings

            self.base = self.base * (
                (self.factor * max_position_embeddings / orig_max_position_embeddings)
                - (self.factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                self.base
                ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
            )

            self.max_seq_len_cached = orig_max_position_embeddings
        else:
            raise ValueError(f"Unsupported rope_type: {config.rope_type}")

        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = (
            device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    if q is not None:
        q_embed = (q * cos) + (rotate_half(q) * sin)
    else:
        q_embed = None

    if k is not None:
        k_embed = (k * cos) + (rotate_half(k) * sin)
    else:
        k_embed = None
    return q_embed, k_embed

class EagleRMSNorm(nn.Module):
    """RMSNorm implementation that uses the torch_rmsnorm custom op."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        result = torch.ops.auto_deploy.torch_rmsnorm(
            hidden_states, self.weight, self.variance_epsilon
        )
        return result


class EagleMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

class Eagle3Attention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper."""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.is_causal = True

        # Note: Eagle3Attention expects 2 * hidden_size input, which is the concatenation of the
        # hidden states and the input embeddings.
        self.q_proj = nn.Linear(
            2 * config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            2 * config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            2 * config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.size()
        cos, sin = position_embeddings

        # Projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to [Batch, Seq, Heads, Dim]
        query_states = query_states.view(bsz, q_len, -1, self.head_dim)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim)

        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, unsqueeze_dim=2
        )

        attn_output = torch.ops.auto_deploy.torch_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=self.is_causal,
            layout="bsnd",
        )

        attn_output = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output

class Eagle3DecoderLayer(nn.Module):
    """Eagle decoder layer with modified attention and hidden state normalization."""

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.self_attn = Eagle3Attention(config, layer_idx=layer_idx)
        self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = EagleMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        embeds: torch.Tensor,
        position_embeds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.hidden_norm(hidden_states)

        embeds = self.input_layernorm(embeds)

        # Concatenate normalized input embeddings with normalized hidden states (2 * hidden_size).
        hidden_states = torch.cat([embeds, hidden_states], dim=-1)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeds,
        )

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class Eagle3Model(nn.Module):
    """Core Eagle model architecture."""

    def __init__(self, config):
        super().__init__()

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size:
            # Vocab mappings for draft <-> target token conversion.
            # Needed to convert draft outputs to target inputs for Eagle3.
            # Since we reuse the target model's embedding in the drafter, we need
            # to do this conversion after every draft iteration.
            self.d2t = nn.Parameter(
                torch.empty((config.draft_vocab_size,), dtype=torch.int32),
                requires_grad=False,
            )

        # Input feature fusion: 3 * hidden_size -> hidden_size for Eagle3.
        # TODO: Can make this configurable based on number of capture layers.
        self.fc = nn.Linear(
            config.hidden_size * 3,
            config.hidden_size,
            bias=getattr(config, "bias", False),
            dtype=config.torch_dtype,
        )

        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )

        self.rotary_emb = LlamaRotaryEmbedding(
            config=config, dim=self.head_dim, device=torch.device("cuda")
        )

        if config.num_hidden_layers > 1:
            self.midlayer = nn.ModuleList(
                [Eagle3DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
            )
        else:
            self.midlayer = Eagle3DecoderLayer(config, layer_idx=0)

        self.num_hidden_layers = config.num_hidden_layers

    # Assumption: The hidden states are already fused if necessary.
    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        input_embeds = self.embed_tokens(input_ids)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        position_embeds = (cos, sin)

        if self.num_hidden_layers > 1:
            for layer in self.midlayer:
                hidden_states = layer(
                    hidden_states=hidden_states,
                    embeds=input_embeds,
                    position_embeds=position_embeds,
                )
        else:
            hidden_states = self.midlayer(
                hidden_states=hidden_states,
                embeds=input_embeds,
                position_embeds=position_embeds,
            )

        return hidden_states

    def apply_eagle3_fc(self, target_hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc(target_hidden_states)

@dataclass
class Eagle3DraftOutput(ModelOutput):
    logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None

class Eagle3DrafterForCausalLM(PreTrainedModel):
    """HuggingFace-compatible wrapper for Eagle3Model.

    This wrapper makes Eagle3Model compatible with AutoDeploy's model loading
    and inference pipeline.
    """

    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["Eagle3DecoderLayer"]

    # Checkpoint conversion mapping: Eagle checkpoints have keys like "fc.weight"
    # but the wrapper model expects "model.fc.weight" (due to self.model = Eagle3Model).
    # This mapping tells the factory to add the "model." prefix when loading weights.
    # Used by AutoModelForCausalLMFactory._remap_param_names_load_hook().
    _checkpoint_conversion_mapping = {
        "^(?!lm_head|norm)": "model.",  # Prepend "model." to all keys EXCEPT lm_head and norm
    }

    def __init__(self, config):
        super().__init__(config)
        self.model = Eagle3Model(config)
        self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)

        eagle_config = getattr(config, "eagle_config", {})
        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Eagle3DraftOutput:
        """
        Kwargs:
            hidden_states: Hidden states from the target model. Required.

        Raises:
            ValueError: If hidden_states is not provided in kwargs.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Generate position_ids if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
            position_ids = position_ids.expand(batch_size, -1)

        hidden_states = kwargs.get("hidden_states")
        if hidden_states is None:
            raise ValueError("hidden_states must be provided.")

        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
        )

        norm_hidden_states = self.norm(hidden_states)
        logits = self.lm_head(norm_hidden_states)

        last_hidden_state = norm_hidden_states if self._return_hidden_post_norm else hidden_states

        return Eagle3DraftOutput(
            logits=logits,
            last_hidden_state=last_hidden_state,
        )
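For illustration, a minimal self-contained sketch of how the _checkpoint_conversion_mapping regex above remaps checkpoint keys; the example key names are hypothetical, not taken from a real checkpoint:

import re

conversion_mapping = {"^(?!lm_head|norm)": "model."}

# Every key that does not start with lm_head or norm gains a "model." prefix.
for key in ["fc.weight", "midlayer.self_attn.q_proj.weight", "lm_head.weight", "norm.weight"]:
    new_key = key
    for pattern, replacement in conversion_mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
# fc.weight -> model.fc.weight
# midlayer.self_attn.q_proj.weight -> model.midlayer.self_attn.q_proj.weight
# lm_head.weight -> lm_head.weight
# norm.weight -> norm.weight
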
tensorrt_llm/_torch/auto_deploy/models/eagle.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory definitions for building models related to Eagle in AutoDeploy.

This module provides EagleDrafterFactory, a specialized factory for building
Eagle speculative decoding draft models. It extends AutoModelForCausalLMFactory
to handle the mapping from base model types (e.g., "llama") to their corresponding
Eagle drafter implementations.
"""

from contextlib import nullcontext
from typing import Dict, Type

import torch.nn as nn
from accelerate import init_empty_weights
from torch._prims_common import DeviceLikeType
from transformers import PreTrainedModel

from .custom.modeling_eagle import Eagle3DrafterForCausalLM
from .factory import ModelFactoryRegistry
from .hf import AutoModelForCausalLMFactory

@ModelFactoryRegistry.register("EagleDrafter")
class EagleDrafterFactory(AutoModelForCausalLMFactory):
    """Factory for building Eagle drafter models.

    This factory handles the mapping from base model types (e.g., "llama") to
    their corresponding Eagle drafter model implementations. It overrides
    _build_model() to directly construct the appropriate drafter class based
    on the checkpoint's model_type.

    The checkpoint config is expected to have the base model's model_type
    (e.g., "llama") along with Eagle-specific fields like draft_vocab_size.
    """

    # Map config model_type -> Eagle drafter model class
    _drafter_model_mapping: Dict[str, Type[PreTrainedModel]] = {
        "llama": Eagle3DrafterForCausalLM,
    }

    def _build_model(self, device: DeviceLikeType) -> nn.Module:
        model_config, unused_kwargs = self._get_model_config()

        # Select the appropriate drafter class based on the base model type
        match model_config.model_type:
            case "llama":
                drafter_cls = self._drafter_model_mapping["llama"]
            case _:
                raise ValueError(
                    f"Unsupported model_type '{model_config.model_type}' for Eagle drafter. "
                    f"Supported types: {list(self._drafter_model_mapping.keys())}"
                )

        # Build the model (same pattern as the parent's _build_model)
        with (init_empty_weights if device == "meta" else nullcontext)():
            model = drafter_cls._from_config(model_config, **unused_kwargs)

        if device == "meta":
            # post-init must be called explicitly for HF models with init_empty_weights
            if hasattr(model, "post_init"):
                model.post_init()
        else:
            model.to(device)

        # Store checkpoint conversion mapping if present
        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)

        model.eval()

        return model

    def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
        raise NotImplementedError(
            "EagleDrafterFactory does not support build_and_load_model(). "
            "Use build_model() + load_or_random_init() instead."
        )
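To make the factory flow concrete, a minimal usage sketch assuming a locally available Eagle3 checkpoint; the path is a placeholder, and the sketch mirrors the flow exercised by the tests below rather than a documented API walkthrough:

import torch

from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory

# Placeholder path to a local Eagle3 drafter checkpoint.
eagle_checkpoint = "/path/to/EAGLE3-LLaMA3.1-Instruct-8B"

factory = EagleDrafterFactory(model=eagle_checkpoint, skip_loading_weights=False)

# Build on the meta device first, then materialize and load weights.
model = factory.build_model("meta")
device = "cuda" if torch.cuda.is_available() else "cpu"
factory.load_or_random_init(model, device)

# Note: build_and_load_model() intentionally raises NotImplementedError for this factory.
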
@@ -14,13 +14,17 @@
# limitations under the License.

import os
import re
from pathlib import Path

import pytest
import torch
from build_and_run_ad import ExperimentConfig, main
from defs.conftest import llm_models_root

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig

prompts = [
@@ -276,3 +280,171 @@ def test_autodeploy_eagle3_acceptance_rate():
    print("\n" + "=" * 80)
    print("SUCCESS! All requests passed acceptance rate threshold")
    print("=" * 80)


def load_weights(model_path: Path, model: torch.nn.Module):
    """Analyze checkpoint keys while applying the same _checkpoint_conversion_mapping that the factory uses.

    Returns: tuple of (loaded_keys, missing_keys, unexpected_keys)
    """
    # 1. Load checkpoint keys
    bin_path = model_path / "pytorch_model.bin"
    safetensors_path = model_path / "model.safetensors"

    if safetensors_path.exists():
        from safetensors import safe_open

        with safe_open(safetensors_path, framework="pt") as f:
            checkpoint_keys_original = list(f.keys())
    elif bin_path.exists():
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
        checkpoint_keys_original = list(state_dict.keys())
        del state_dict
    else:
        raise FileNotFoundError(f"No checkpoint found at {model_path}")

    # 2. Apply _checkpoint_conversion_mapping (same logic as hf.py's _remap_param_names_load_hook).
    # This is the key part: the factory does this exact same thing in lines 496-512 of hf.py.
    conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
    checkpoint_keys_remapped = []

    for key in checkpoint_keys_original:
        new_key = key
        if conversion_mapping:
            for pattern, replacement in conversion_mapping.items():
                new_key = re.sub(pattern, replacement, new_key)
        checkpoint_keys_remapped.append(new_key)

    # 3. Get the model's expected keys
    model_keys = set(model.state_dict().keys())
    checkpoint_keys = set(checkpoint_keys_remapped)

    # 4. Calculate differences
    loaded_keys = checkpoint_keys & model_keys
    missing_in_checkpoint = model_keys - checkpoint_keys
    unexpected_in_checkpoint = checkpoint_keys - model_keys

    return loaded_keys, missing_in_checkpoint, unexpected_in_checkpoint

def test_eagle_model_with_weights():
    """Test EagleModel forward pass with loaded weights using the EagleDrafterFactory.

    This test uses EagleDrafterFactory to initialize the model, which directly
    builds the Eagle drafter model based on the checkpoint's model_type:

    1. Factory creates config via AutoConfig.from_pretrained
    2. Factory selects Eagle3DrafterForCausalLM based on model_type="llama"
    3. Factory creates model via _from_config
    4. Factory loads weights via load_or_random_init -> _load_checkpoint

    This ensures the test validates the exact initialization path used in production.
    """
    print("\n" + "=" * 80)
    print("Test: EagleModel forward pass with loaded weights (via EagleDrafterFactory)")
    print("=" * 80)

    _, _, eagle_model_path = get_model_paths()
    eagle_path = Path(eagle_model_path)

    if not eagle_path.exists():
        pytest.skip(f"Eagle model not found at {eagle_model_path}")

    # Check for weights
    bin_path = eagle_path / "pytorch_model.bin"
    safetensors_path = eagle_path / "model.safetensors"
    if not bin_path.exists() and not safetensors_path.exists():
        pytest.skip(f"Weights not found at {eagle_model_path}")

    # 1. Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 2. Create factory
    # EagleDrafterFactory directly builds the correct drafter model based on model_type
    print("Creating EagleDrafterFactory...")
    factory = EagleDrafterFactory(
        model=eagle_model_path,
        skip_loading_weights=False,  # We want to test weight loading
    )

    # 3. Build model using factory
    # Factory flow:
    #   build_model() -> prefetch_checkpoint() -> _build_model()
    #   _build_model() -> _get_model_config() (gets base LlamaConfig)
    #   _build_model() -> selects Eagle3DrafterForCausalLM for model_type="llama"
    #   _build_model() -> Eagle3DrafterForCausalLM._from_config(config)
    print("Building model via factory.build_model('meta')...")
    model = factory.build_model("meta")
    print(f"Model type: {type(model).__name__}")
    print(f"Model config type: {type(model.config).__name__}")

    # 4. Load weights from checkpoint and compare to model's expected keys
    print("\n--- Weight Loading Analysis ---")
    loaded_keys, missing_keys, unexpected_keys = load_weights(eagle_path, model)

    print(f"Total model parameters: {len(loaded_keys) + len(missing_keys)}")
    print(f"Total checkpoint keys: {len(loaded_keys) + len(unexpected_keys)}")
    print(f"✅ Weights to be loaded: {len(loaded_keys)}")
    print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}")
    print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}")

    if missing_keys:
        print("\nMissing keys (model expects but checkpoint doesn't have):")
        for key in sorted(missing_keys):
            if "embed_tokens" in key:
                print(f"  - {key} (expected: shared from target model)")
            elif "rotary_emb" in key:
                print(f"  - {key} (expected: computed at runtime)")
            else:
                print(f"  - {key}")

    if unexpected_keys:
        print("\nUnexpected keys (in checkpoint but model doesn't expect):")
        for key in sorted(unexpected_keys):
            if "t2d" in key:
                print(f"  - {key} (expected: not used in Eagle3)")
            else:
                print(f"  - {key}")

    if loaded_keys:
        print(f"\nLoaded keys ({len(loaded_keys)} total):")
        for key in sorted(loaded_keys)[:10]:
            print(f"  - {key}")
        if len(loaded_keys) > 10:
            print(f"  ... and {len(loaded_keys) - 10} more")

    print("--- End Weight Analysis ---\n")

    # Verify expected missing and unexpected keys
    # These are the keys we expect based on the Eagle3 architecture:
    # - embed_tokens: shared from target model (not in Eagle checkpoint)
    # - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead)
    expected_missing_keys = {"model.embed_tokens.weight"}
    expected_unexpected_keys = {"model.t2d"}

    assert missing_keys == expected_missing_keys, (
        f"Unexpected missing keys.\n"
        f"Expected: {expected_missing_keys}\n"
        f"Got: {missing_keys}\n"
        f"Extra missing: {missing_keys - expected_missing_keys}\n"
        f"Not missing (but expected): {expected_missing_keys - missing_keys}"
    )

    assert unexpected_keys == expected_unexpected_keys, (
        f"Unexpected keys in checkpoint.\n"
        f"Expected: {expected_unexpected_keys}\n"
        f"Got: {unexpected_keys}\n"
        f"Extra unexpected: {unexpected_keys - expected_unexpected_keys}\n"
        f"Not unexpected (but expected): {expected_unexpected_keys - unexpected_keys}"
    )

    print("✅ Weight loading analysis matches expected missing/unexpected keys!")

    # 5. Load weights using factory (mimics the actual pipeline)
    # If tensor shapes do not match how they are used in the forward() function, we will
    # get an error here.
    print("Loading weights via factory.load_or_random_init()...")
    factory.load_or_random_init(model, device)
    print("Weights loaded successfully via factory interface!")

    model.eval()

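A tiny standalone sketch of the key-set arithmetic performed by load_weights() above; the key names are illustrative, chosen to mirror the expected Eagle3 outcome:

# Illustrative key sets (not read from a real checkpoint).
model_keys = {"model.fc.weight", "model.embed_tokens.weight", "lm_head.weight"}
checkpoint_keys = {"model.fc.weight", "lm_head.weight", "model.t2d"}

loaded = checkpoint_keys & model_keys        # weights that will be loaded
missing = model_keys - checkpoint_keys       # e.g. embed_tokens, shared from the target model
unexpected = checkpoint_keys - model_keys    # e.g. t2d, unused by the Eagle3 drafter

print(loaded)      # {'model.fc.weight', 'lm_head.weight'}
print(missing)     # {'model.embed_tokens.weight'}
print(unexpected)  # {'model.t2d'}
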
@@ -441,3 +441,4 @@ l0_h100:
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
- examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights
@@ -484,6 +484,11 @@ _SMALL_MODEL_CONFIGS = {
            "num_hidden_layers": 8,
        },
    },
    "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": {
        "model_kwargs": {
            "hidden_size": 64,
        }
    },
}

@@ -0,0 +1,183 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for Eagle3 model with AutoDeploy."""

from pathlib import Path

import pytest
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
from transformers import AutoConfig

from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
    Eagle3DrafterForCausalLM,
    Eagle3Model,
)
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry
from tests.test_common.llm_data import hf_id_to_local_model_dir

EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


###############################################################################
# Mock classes for standalone Eagle testing
#
# These classes enable unit testing the Eagle checkpoint without a target model.
# In production speculative decoding, real hidden states come from the target model.
# For testing, MockEagle3ModelForCausalLM generates random hidden states.
###############################################################################

class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
    """Test wrapper that provides random hidden states for standalone Eagle testing.

    In production speculative decoding, real hidden states come from the target model.
    This mock class generates random hidden states for testing the Eagle model in isolation.
    """

    def __init__(self, config):
        super().__init__(config)
        self._hidden_size = config.hidden_size
        self._dtype = config.dtype

    def forward(self, input_ids, **kwargs):
        # Inject mock hidden states if not provided
        if "hidden_states" not in kwargs:
            batch_size, seq_len = input_ids.shape
            kwargs["hidden_states"] = torch.randn(
                (batch_size, seq_len, self._hidden_size),
                dtype=self._dtype,
                device=input_ids.device,
            )
        return super().forward(input_ids, **kwargs)


class MockEagleDrafterFactory(EagleDrafterFactory):
    """Test factory that uses MockEagle3ModelForCausalLM for standalone Eagle testing.

    This factory overrides the drafter mapping to use the mock model class, which
    generates random hidden states, enabling testing without a target model.
    """

    _drafter_model_mapping = {
        "llama": MockEagle3ModelForCausalLM,
    }


@pytest.fixture
def register_mock_eagle_factory():
    """Register MockEagleDrafterFactory for the test and clean up afterwards.

    This fixture temporarily registers the mock factory with ModelFactoryRegistry,
    allowing tests to use model_factory="MockEagleDrafter", and removes the
    registration after the test completes.
    """
    ModelFactoryRegistry._registry["MockEagleDrafter"] = MockEagleDrafterFactory
    yield
    ModelFactoryRegistry._registry.pop("MockEagleDrafter", None)


def test_build_ad_eagle(register_mock_eagle_factory):
    """Test building the Eagle model with AutoDeploy using MockEagleDrafterFactory.

    This test uses the MockEagleDrafterFactory, which builds MockEagle3ModelForCausalLM,
    a mock model that generates random hidden states for standalone Eagle testing.
    """
    llm_extra_args = {
        "model_factory": "MockEagleDrafter",
        "transforms": {
            "insert_cached_attention": {"backend": "flashinfer"},
            "compile_model": {"backend": "torch-compile"},
        },
    }
    experiment_config = get_small_model_config(EAGLE_MODEL_HUB_ID, **llm_extra_args)
    experiment_config["args"]["runtime"] = "demollm"
    experiment_config["args"]["world_size"] = 0
    experiment_config["args"]["tokenizer"] = hf_id_to_local_model_dir(
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    )

    print(f"Experiment Config: {experiment_config}")
    experiment_config = ExperimentConfig(**experiment_config)

    main(experiment_config)


def test_eagle_model_torch_export():
    """Test that Eagle3Model can be exported with torch.export.

    This validates that the model architecture is compatible with
    torch.export for potential TensorRT compilation.

    Note: We skip loading weights since torch.export only traces the computation
    graph (model architecture), not the actual weight values. Random init is fine.
    """
    print("\n" + "=" * 80)
    print("Test: EagleModel torch.export")
    print("=" * 80)

    eagle_model_path = hf_id_to_local_model_dir(EAGLE_MODEL_HUB_ID)
    if eagle_model_path is None:
        pytest.skip("Eagle model not found (LLM_MODELS_ROOT not set or model missing)")

    eagle_path = Path(eagle_model_path)

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16

    # Use the pretrained (llama) config - it can be provided directly to instantiate Eagle3Model.
    config_path = eagle_path / "config.json"
    config = AutoConfig.from_pretrained(config_path)

    # Create model with random weights (no need to load them for an export test)
    model = Eagle3Model(config)
    model.to(device)
    model.eval()

    # Create inputs for export
    batch_size = 1
    seq_len = 8
    hidden_dim = config.hidden_size

    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long
    )
    position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    mock_hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype)

    print("Export input shapes:")
    print(f"  input_ids: {input_ids.shape}")
    print(f"  position_ids: {position_ids.shape}")
    print(f"  hidden_states: {mock_hidden_states.shape}")

    example_args = (
        input_ids,
        position_ids,
        mock_hidden_states,
    )

    # Attempt torch.export
    try:
        exported_program = torch.export.export(model, args=example_args)
        print("✅ torch.export successful!")
        print("Graph module code preview (first 20 lines):")
        code_lines = exported_program.graph_module.code.split("\n")[:20]
        print("\n".join(code_lines))
    except Exception as e:
        pytest.fail(f"torch.export failed: {e}")
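For reference, a minimal self-contained sketch of the same torch.export pattern used above, on a toy module unrelated to the Eagle model:

import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A trivial computation so the exported graph has something to show.
        return torch.nn.functional.relu(x) + 1.0


# torch.export traces the module into an ExportedProgram whose graph can be inspected.
exported = torch.export.export(Toy(), args=(torch.randn(2, 4),))
print(exported.graph_module.code)
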
@@ -21,7 +21,7 @@ from test_common.llm_data import with_mocked_hf_download
 from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig


-@pytest.mark.parametrize("use_hf_speculative_model", [False, True])
+@pytest.mark.parametrize("use_hf_speculative_model", [False])
 @with_mocked_hf_download
 def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):
     """Test speculative decoding with AutoDeploy using the build_and_run_ad main()."""