Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
This commit is contained in:
parent
3ef8a4639b
commit
585fbb2734
@@ -22,14 +22,59 @@ This file contains model definitions used for executing Eagle3 speculative decod
"""

from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput

from ...utils._config import deep_merge_dicts
from ...utils.logger import ad_logger

class EagleConfig(PretrainedConfig):
    """Config for Eagle3 drafter models.

    Extends PretrainedConfig with Eagle-specific parameters while preserving
    all base model config values.

    Args:
        config: Base config for the draft model from its config.json.
        model_type: The base model type (e.g., "llama") used to look up defaults.
    """

    # Map model_type -> default Eagle config values
    _drafter_defaults: Dict[str, Dict[str, Any]] = {
        "llama": {
            "load_embedding_from_target": True,
            "load_lm_head_from_target": False,
            "num_capture_layers": 3,
        },
    }

    def __init__(self, config: PretrainedConfig, model_type: str):
        if model_type not in self._drafter_defaults:
            raise ValueError(
                f"Unsupported model_type '{model_type}' for EagleConfig. "
                f"Supported types: {list(self._drafter_defaults.keys())}"
            )

        defaults = self._drafter_defaults[model_type]
        config_dict = config.to_dict()

        # Log when config overrides a default
        for key, value in defaults.items():
            if key in config_dict and config_dict[key] != value:
                ad_logger.info(
                    f"EagleConfig: config has '{key}={config_dict[key]}', "
                    f"overriding default '{value}'"
                )

        merged = deep_merge_dicts(defaults, config_dict)
        super().__init__(**merged)
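
# Illustrative usage sketch (toy values; the real draft config comes from the Eagle
# checkpoint's config.json). It shows the merge behaviour implemented above: drafter
# defaults are filled in, while values already present in the base config are preserved.
from transformers import LlamaConfig

_toy_base = LlamaConfig(hidden_size=256, num_hidden_layers=1, vocab_size=1024)
_toy_cfg = EagleConfig(_toy_base, model_type="llama")
assert _toy_cfg.num_capture_layers == 3              # drafter default applied
assert _toy_cfg.load_embedding_from_target is True   # drafter default applied
assert _toy_cfg.hidden_size == 256                   # base config value preserved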
|
||||
|
||||
|
||||
class LlamaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config, dim, device=None, scaling_factor=1.0):
|
||||
@@ -247,6 +292,7 @@ class Eagle3DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config, layer_idx: int = 0):
|
||||
super().__init__()
|
||||
self.dtype = config.torch_dtype
|
||||
self.self_attn = Eagle3Attention(config, layer_idx=layer_idx)
|
||||
self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -287,7 +333,14 @@ class Eagle3Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.dtype = config.torch_dtype
|
||||
|
||||
load_embedding_from_target = getattr(config, "load_embedding_from_target", False)
|
||||
self.embed_tokens = (
|
||||
None
|
||||
if load_embedding_from_target
|
||||
else nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
)
|
||||
|
||||
if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size:
|
||||
# Vocab mappings for draft <-> target token conversion
|
||||
@@ -299,13 +352,17 @@ class Eagle3Model(nn.Module):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Input feature fusion: 3 * hidden_size -> hidden_size for Eagle3.
|
||||
# TODO: Can make this configurable based on number of capture layers.
|
||||
self.fc = nn.Linear(
|
||||
config.hidden_size * 3,
|
||||
config.hidden_size,
|
||||
bias=getattr(config, "bias", False),
|
||||
dtype=config.torch_dtype,
|
||||
# Hidden size compression for target hidden states.
|
||||
# Assumption: No feedforward fusion needed if we have just one capture layer (valid for MTPEagle)
|
||||
self.fc = (
|
||||
nn.Linear(
|
||||
config.hidden_size * config.num_capture_layers,
|
||||
config.hidden_size,
|
||||
bias=getattr(config, "bias", False),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if config.num_capture_layers > 1
|
||||
else None
|
||||
)
|
||||
|
||||
self.head_dim = getattr(
|
||||
@@ -328,12 +385,10 @@ class Eagle3Model(nn.Module):
|
||||
# Assumption: The hidden states are already fused if necessary
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.LongTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
||||
position_embeds = (cos, sin)
|
||||
|
||||
@@ -341,25 +396,23 @@ class Eagle3Model(nn.Module):
|
||||
for layer in self.midlayer:
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
embeds=input_embeds,
|
||||
embeds=inputs_embeds,
|
||||
position_embeds=position_embeds,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.midlayer(
|
||||
hidden_states=hidden_states,
|
||||
embeds=input_embeds,
|
||||
embeds=inputs_embeds,
|
||||
position_embeds=position_embeds,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def apply_eagle3_fc(self, target_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.fc(target_hidden_states)
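
# Shape sketch for the fc fusion applied above (standalone, toy sizes): with
# num_capture_layers == 3 the fc maps concatenated target hidden states
# [batch, seq, 3 * hidden] down to the drafter's working size [batch, seq, hidden];
# with a single capture layer, fc is None and the hidden states pass through unchanged.
import torch
import torch.nn as nn

_toy_hidden, _toy_k = 64, 3
_toy_fc = nn.Linear(_toy_hidden * _toy_k, _toy_hidden, bias=False)
assert _toy_fc(torch.randn(2, 5, _toy_hidden * _toy_k)).shape == (2, 5, _toy_hidden)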
|
||||
|
||||
|
||||
@dataclass
|
||||
class Eagle3DraftOutput(ModelOutput):
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
norm_hidden_state: Optional[torch.FloatTensor] = None
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@@ -385,16 +438,24 @@ class Eagle3DrafterForCausalLM(PreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.load_embedding_from_target = getattr(config, "load_embedding_from_target", False)
|
||||
self.load_lm_head_from_target = getattr(config, "load_lm_head_from_target", False)
|
||||
|
||||
self.model = Eagle3Model(config)
|
||||
self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
|
||||
self.lm_head = (
|
||||
None
|
||||
if self.load_lm_head_from_target
|
||||
else nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
|
||||
)
|
||||
|
||||
eagle_config = getattr(config, "eagle_config", {})
|
||||
self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
inputs_embeds: torch.LongTensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Eagle3DraftOutput:
|
||||
@@ -405,8 +466,8 @@ class Eagle3DrafterForCausalLM(PreTrainedModel):
|
||||
Raises:
|
||||
ValueError: If hidden_states is not provided in kwargs.
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
device = inputs_embeds.device
|
||||
|
||||
# Generate position_ids if not provided
|
||||
if position_ids is None:
|
||||
@@ -418,17 +479,419 @@ class Eagle3DrafterForCausalLM(PreTrainedModel):
|
||||
raise ValueError("hidden_states must be provided.")
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
inputs_embeds=inputs_embeds, position_ids=position_ids, hidden_states=hidden_states
|
||||
)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
logits = self.lm_head(norm_hidden_states)
|
||||
norm_hidden_state = self.norm(hidden_states)
|
||||
|
||||
last_hidden_state = norm_hidden_states if self._return_hidden_post_norm else hidden_states
|
||||
last_hidden_state = norm_hidden_state if self._return_hidden_post_norm else hidden_states
|
||||
|
||||
return Eagle3DraftOutput(
|
||||
logits=logits,
|
||||
norm_hidden_state=norm_hidden_state,
|
||||
last_hidden_state=last_hidden_state,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
if self.model.embed_tokens is not None:
|
||||
return self.model.embed_tokens
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Eagle3DrafterForCausalLM does not have an input embedding layer."
|
||||
)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
if self.lm_head is not None:
|
||||
return self.lm_head
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Eagle3DrafterForCausalLM does not have an output embedding layer."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleWrapperOutput(ModelOutput):
|
||||
"""Output format compatible with Eagle3OneModelSampler/MTPSampler.
|
||||
|
||||
This output format allows the one-model speculative decoding flow to bypass
|
||||
logits-based sampling in the sampler. The EagleWrapper performs all sampling
|
||||
and verification internally, returning pre-computed tokens.
|
||||
"""
|
||||
|
||||
# logits: [batch_size, 1, vocab_size]. Used for compatibility.
|
||||
logits: Optional[torch.Tensor] = None
|
||||
|
||||
# new_tokens: [batch_size, max_draft_len + 1]. Accepted tokens from verification.
|
||||
# This is a 2D tensor where each row contains the accepted tokens for a request,
|
||||
# padded if fewer tokens were accepted.
|
||||
new_tokens: Optional[torch.Tensor] = None
|
||||
|
||||
# new_tokens_lens: [batch_size]. Number of newly accepted tokens per request in this iteration.
|
||||
new_tokens_lens: Optional[torch.Tensor] = None
|
||||
|
||||
# next_draft_tokens: [batch_size, max_draft_len]. Draft tokens for the next iteration.
|
||||
# These are the tokens predicted by the draft model, already converted via d2t.
|
||||
next_draft_tokens: Optional[torch.Tensor] = None
|
||||
|
||||
# next_new_tokens: [batch_size, max_draft_len + 1]. Input tokens for the next iteration.
|
||||
# Format: [last_accepted_token, draft_token_0, draft_token_1, ...]
|
||||
next_new_tokens: Optional[torch.Tensor] = None
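
# Toy construction sketch (illustrative sizes only: batch_size=2, max_draft_len=3, vocab=16),
# matching the field shapes documented above.
import torch

_b, _d, _v = 2, 3, 16
_toy_out = EagleWrapperOutput(
    logits=torch.zeros(_b, 1, _v),
    new_tokens=torch.zeros(_b, _d + 1, dtype=torch.long),
    new_tokens_lens=torch.ones(_b, dtype=torch.long),
    next_draft_tokens=torch.zeros(_b, _d, dtype=torch.long),
    next_new_tokens=torch.zeros(_b, _d + 1, dtype=torch.long),
)
assert _toy_out.new_tokens.shape == (_b, _d + 1)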
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleWrapperConfig:
|
||||
max_draft_len: int
|
||||
load_embedding_from_target: bool
|
||||
load_lm_head_from_target: bool
|
||||
|
||||
|
||||
class EagleWrapper(nn.Module):
|
||||
def __init__(self, config, target_model, draft_model, resource_manager):
|
||||
super().__init__()
|
||||
self.target_model = target_model
|
||||
self.draft_model = draft_model
|
||||
self.resource_manager = resource_manager
|
||||
self.max_draft_len = config.max_draft_len
|
||||
self.load_embedding_from_target = config.load_embedding_from_target
|
||||
self.load_lm_head_from_target = config.load_lm_head_from_target
|
||||
|
||||
def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply the fc layer that fuses hidden states from multiple target layers."""
|
||||
draft_model = self.draft_model.model
|
||||
hidden_states = hidden_states.to(draft_model.dtype)
|
||||
|
||||
fc = getattr(draft_model, "fc", None)
|
||||
if fc is not None:
|
||||
hidden_states = fc(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def apply_d2t(self, draft_output_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply draft-to-target token mapping if available."""
|
||||
d2t = getattr(self.draft_model.model, "d2t", None)
|
||||
if d2t is not None:
|
||||
draft_output_ids = d2t[draft_output_ids] + draft_output_ids
|
||||
return draft_output_ids
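
# Numeric sketch of the d2t convention used above (toy values): d2t stores per-draft-token
# offsets, so a draft-vocab id maps to target-vocab id d2t[id] + id.
import torch

_toy_d2t = torch.tensor([100, 100, 205])          # offsets for draft ids 0, 1, 2
_toy_draft_ids = torch.tensor([0, 2])
assert (_toy_d2t[_toy_draft_ids] + _toy_draft_ids).tolist() == [100, 207]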
|
||||
|
||||
def apply_draft_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply embedding to input_ids for the draft model."""
|
||||
if self.load_embedding_from_target:
|
||||
embeds = self.target_model.get_input_embeddings()(input_ids)
|
||||
return embeds.to(self.draft_model.dtype)
|
||||
else:
|
||||
return self.draft_model.get_input_embeddings()(input_ids)
|
||||
|
||||
def apply_lm_head(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply lm_head to get logits from hidden states."""
|
||||
if self.load_lm_head_from_target:
|
||||
lm_head_weights = self.target_model.get_output_embeddings()(hidden_states)
|
||||
return lm_head_weights.to(self.draft_model.dtype)
|
||||
else:
|
||||
return self.draft_model.get_output_embeddings()(hidden_states)
|
||||
|
||||
def sample_greedy(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
ret = torch.argmax(logits, dim=-1)
|
||||
return ret
|
||||
|
||||
def sample_and_verify(
|
||||
self, input_ids, target_logits: torch.Tensor, num_previously_accepted: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch_size, seq_len]
|
||||
target_logits: [batch_size, seq_len, vocab_size]
|
||||
num_previously_accepted: [batch_size]. Number of input tokens accepted so far for each batch.
|
||||
|
||||
Returns:
|
||||
output_ids: [batch_size, seq_len]. The accepted tokens followed by the newly sampled bonus token, right-padded with zeros.
|
||||
num_newly_accepted_tokens: [batch_size]. Number of newly accepted tokens in each batch.
|
||||
num_accepted_tokens: [batch_size]. Number of tokens accepted in each batch, including previously accepted.
|
||||
So num_accepted_tokens[i] = num_previously_accepted[i] + num_newly_accepted_tokens[i].
|
||||
last_logits_3d: [batch_size, 1, vocab_size]. The logit used to sample the bonus token.
|
||||
|
||||
How it works:
|
||||
- Get input ids that were not previously accepted.
|
||||
- Get the target logits corresponding to these input ids (target_logit[j-1] corresponds to input_ids[j]).
|
||||
- Sample a token from the logits for each batch and compare to input_ids to get the newly accepted tokens.
|
||||
- The output_ids consist of the previously accepted tokens, the newly accepted tokens,
|
||||
and a newly sampled token after the last accepted token.
|
||||
"""
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# First, check that num_previously_accepted is <= seq_len for each batch
|
||||
# Additionally, num_previously_accepted should be >= 1 for each batch,
|
||||
# which corresponds to having some context tokens (context tokens are always accepted).
|
||||
assert (num_previously_accepted >= 1).all(), (
|
||||
"num_previously_accepted must be >= 1. Please provide non-empty context in each batch."
|
||||
)
|
||||
assert (num_previously_accepted <= seq_len).all(), (
|
||||
"num_previously_accepted must be <= seq_len for each batch"
|
||||
)
|
||||
|
||||
# We get input tokens that were not yet accepted.
|
||||
unchecked_input_ids = [
|
||||
input_ids[i, num_previously_accepted[i] : seq_len].unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
|
||||
# We get the corresponding target logits for the unchecked input tokens.
|
||||
# logit j-1 corresponds to input j.
|
||||
# Note that this indexing is valid because of our check that num_previously_accepted is >= 1.
# We also get the logit following the last input token in each batch, because we may need to
# sample a bonus token from it and append it at the end.
|
||||
unchecked_target_logits = [
|
||||
target_logits[i, (num_previously_accepted[i] - 1) : seq_len, :].unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
|
||||
unchecked_output_ids = [self.sample_greedy(x) for x in unchecked_target_logits]
|
||||
|
||||
# corresponding_output_ids: [batch_size, seq_len - 1]. The output ids that correspond to the unchecked input ids.
# Omits the last index because that corresponds to a freshly sampled output token.
corresponding_output_ids = [output_id[:, :-1] for output_id in unchecked_output_ids]
|
||||
|
||||
# After sample_greedy, corresponding_output_ids should have same shape as unchecked_input_ids
|
||||
assert [x.shape for x in unchecked_input_ids] == [
|
||||
x.shape for x in corresponding_output_ids
|
||||
], "unchecked_input_ids and corresponding_output_ids must have the same shape"
|
||||
|
||||
matches = [
|
||||
(corresponding_output_ids[i] == unchecked_input_ids[i]).int() for i in range(batch_size)
|
||||
]
|
||||
|
||||
# Compute num_newly_accepted_tokens per batch (handles different sizes across batches)
|
||||
num_newly_accepted_tokens = []
|
||||
for i in range(batch_size):
|
||||
if matches[i].numel() == 0:
|
||||
# No unchecked tokens for this batch (num_previously_accepted == seq_len)
|
||||
num_newly_accepted_tokens.append(
|
||||
torch.tensor(0, dtype=torch.long, device=input_ids.device)
|
||||
)
|
||||
else:
|
||||
# prefix_matches[j] is 1 if first j+1 tokens all matched
|
||||
prefix_matches = matches[i].cumprod(dim=-1)
|
||||
num_newly_accepted_tokens.append(prefix_matches.sum().long())
|
||||
num_newly_accepted_tokens = torch.stack(num_newly_accepted_tokens)
|
||||
|
||||
# num_accepted_tokens: [batch_size]. The total number of accepted tokens in each batch,
|
||||
# including previously accepted tokens.
|
||||
num_accepted_tokens = num_previously_accepted + num_newly_accepted_tokens
|
||||
|
||||
assert (num_accepted_tokens <= seq_len).all(), (
|
||||
"num_accepted_tokens must be <= seq_len for each batch"
|
||||
)
|
||||
|
||||
# Construct draft_input_ids for the draft model
|
||||
# For each sequence:
|
||||
# 1. Take previously accepted tokens (skipping the first one)
|
||||
# 2. Append newly accepted tokens directly from input_ids.
|
||||
# 3. Append the sampled token for last accepted position: unchecked_output_ids[0][num_newly_accepted]
|
||||
# 4. Fill the rest with zeros (padding)
|
||||
# Total real tokens: (num_previously_accepted - 1) + num_newly_accepted + 1 = num_accepted_tokens
|
||||
|
||||
draft_input_ids = torch.zeros(
|
||||
(batch_size, seq_len), dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
# 1. Previously accepted tokens (skip the first one in keeping with Eagle convention)
|
||||
# Note that this potentially includes context tokens, but is structured this way because we
|
||||
# want the output to contain the entire prefix of accepted tokens because the drafters have no KV cache.
|
||||
prev_accepted = input_ids[i, 1 : num_previously_accepted[i]]
|
||||
|
||||
# 2. Newly accepted input tokens
|
||||
newly_accepted = input_ids[
|
||||
i,
|
||||
num_previously_accepted[i] : num_previously_accepted[i]
|
||||
+ num_newly_accepted_tokens[i],
|
||||
]
|
||||
|
||||
# 3. The sampled output token for the last accepted position
|
||||
# unchecked_output_ids[i][j] is the sampled token for position (num_previously_accepted + j)
|
||||
# We want the token for position num_accepted_tokens, which is index num_newly_accepted_tokens
|
||||
next_token = unchecked_output_ids[i][0][num_newly_accepted_tokens[i]].unsqueeze(0)
|
||||
|
||||
# Concatenate all parts
|
||||
draft_prefix = torch.cat([prev_accepted, newly_accepted, next_token])
|
||||
|
||||
# Sanity check: draft_prefix length should equal num_accepted_tokens
|
||||
assert draft_prefix.shape[0] == num_accepted_tokens[i], (
|
||||
f"draft_prefix length {draft_prefix.shape[0]} != num_accepted_tokens {num_accepted_tokens[i]}"
|
||||
)
|
||||
|
||||
# Fill into draft_input_ids (rest remains zeros as padding)
|
||||
draft_input_ids[i, : num_accepted_tokens[i]] = draft_prefix
|
||||
|
||||
# Construct last_logits_3d: [batch_size, 1, vocab_size]
|
||||
# This is the logit used to sample the bonus token for each sequence.
|
||||
# The bonus token is sampled from unchecked_target_logits[i][0][num_newly_accepted_tokens[i]]
|
||||
last_logits_list = []
|
||||
for i in range(batch_size):
|
||||
# unchecked_target_logits[i] has shape [1, num_unchecked + 1, vocab_size]
|
||||
# Index num_newly_accepted_tokens[i] gives the logit for the bonus token
|
||||
bonus_logit = unchecked_target_logits[i][0, num_newly_accepted_tokens[i], :].unsqueeze(
|
||||
0
|
||||
)
|
||||
last_logits_list.append(bonus_logit)
|
||||
last_logits_3d = torch.stack(last_logits_list, dim=0) # [batch_size, 1, vocab_size]
|
||||
|
||||
return draft_input_ids, num_newly_accepted_tokens, num_accepted_tokens, last_logits_3d
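
# Worked toy example of the greedy verification rule implemented above (standalone sketch;
# it re-derives the acceptance count with the same prefix-match trick rather than
# instantiating a full EagleWrapper). Tokens [7, 3] are accepted context, [9, 5] are
# unverified draft tokens; the target greedily agrees with 9 but picks 4 instead of 5,
# so one new token is accepted and 4 becomes the bonus token (draft prefix [3, 9, 4]).
import torch

_ids = torch.tensor([[7, 3, 9, 5]])                  # [batch=1, seq=4]
_prev = 2                                            # num_previously_accepted
_logits = torch.full((1, 4, 16), -1.0)
_logits[0, 1, 9] = 1.0                               # logit at position j predicts input j + 1
_logits[0, 2, 4] = 1.0
_greedy = _logits.argmax(dim=-1)                     # [1, 4]
_matches = (_greedy[:, _prev - 1 : -1] == _ids[:, _prev:]).int()
_newly_accepted = _matches.cumprod(dim=-1).sum(dim=-1)
assert _newly_accepted.item() == 1                   # only the drafted 9 is accepted
assert _greedy[0, _prev - 1 + _newly_accepted.item()].item() == 4   # bonus token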
|
||||
|
||||
def forward(self, input_ids, position_ids, **kwargs):
|
||||
"""Dispatch to appropriate forward implementation based on kwargs.
|
||||
|
||||
If num_previously_accepted is provided, use the prefill-only (no KV cache) implementation.
|
||||
Otherwise, this is the KV cache case which is not yet implemented.
|
||||
"""
|
||||
num_previously_accepted = kwargs.get("num_previously_accepted", None)
|
||||
|
||||
if num_previously_accepted is not None:
|
||||
return self._forward_prefill_only(input_ids, position_ids, **kwargs)
|
||||
else:
|
||||
# KV cached case - not implemented yet
|
||||
raise NotImplementedError(
|
||||
"EagleWrapper forward with KV cache is not implemented. "
|
||||
"This code path is reached when num_previously_accepted is not provided in kwargs."
|
||||
)
|
||||
|
||||
def _forward_prefill_only(self, input_ids, position_ids, **kwargs):
|
||||
"""Forward pass without KV cache (prefill-only mode).
|
||||
|
||||
This is the original implementation that recomputes all attention
|
||||
from scratch on every forward call.
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
num_previously_accepted = kwargs.get("num_previously_accepted", None)
|
||||
if num_previously_accepted is None:
|
||||
raise ValueError("num_previously_accepted must be provided for prefill-only mode.")
|
||||
|
||||
# Compute embeddings using the target embedding layer
|
||||
input_embeds = self.target_model.get_input_embeddings()(input_ids)
|
||||
|
||||
# target_logits: [batch_size, seq_len, vocab_size]
|
||||
# Pass embeddings to target model instead of input_ids
|
||||
target_logits = self.target_model(
|
||||
inputs_embeds=input_embeds, position_ids=position_ids
|
||||
).logits
|
||||
|
||||
# output_ids: [batch_size, seq_len]. Contains a prefix of accepted tokens from the target model,
|
||||
# a generated token from the target model, and some padding to fill out the tensor.
|
||||
# num_accepted_tokens: [batch_size]. The number of accepted tokens in each batch.
|
||||
# num_newly_accepted_tokens: [batch_size]. The number of newly accepted tokens in each batch.
|
||||
|
||||
output_ids, num_newly_accepted_tokens, num_accepted_tokens, _ = self.sample_and_verify(
|
||||
input_ids, target_logits, num_previously_accepted
|
||||
)
|
||||
|
||||
# Get hidden states from the resource manager
|
||||
# resource_manager.hidden_states is [max_tokens, hidden_size * num_capture_layers] (flattened)
|
||||
# We slice to get [batch_size * seq_len, hidden_size * num_capture_layers]
|
||||
hidden_states = self.resource_manager.hidden_states[: (batch_size * seq_len), :]
|
||||
|
||||
# Apply the eagle3 fc to reduce the hidden size.
# Note: in prefill-only mode this is wasteful - the fc layer is re-applied to hidden states it
# has already processed in earlier calls, but that recomputation is inherent to prefill-only mode.
# Input: [batch_size * seq_len, hidden_size * num_capture_layers]
# Output: [batch_size * seq_len, hidden_size]
hidden_states = self.apply_eagle3_fc(hidden_states)
|
||||
|
||||
# Reshape from [batch_size * seq_len, hidden_size] to [batch_size, seq_len, hidden_size]
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(batch_size, seq_len, hidden_size)
|
||||
|
||||
# Create a working buffer for the drafting loop in [batch, seq + draft_len, hidden] format.
|
||||
# This is separate from resource_manager.hidden_states which remains in flattened format.
|
||||
all_hidden_states = torch.zeros(
|
||||
(batch_size, seq_len + self.max_draft_len, hidden_size),
|
||||
device=device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
# Copy the initial hidden states from target model
|
||||
all_hidden_states[:, :seq_len, :] = hidden_states
|
||||
|
||||
# Construct our inputs for the drafting loop.
|
||||
# We want tensors that will be able to hold all the tokens we draft.
|
||||
|
||||
dummy_input_ids = torch.zeros(
|
||||
(batch_size, self.max_draft_len), device=device, dtype=output_ids.dtype
|
||||
)
|
||||
|
||||
# draft_input_ids: [batch_size, seq_len + self.max_draft_len]
|
||||
draft_input_ids = torch.cat((output_ids, dummy_input_ids), dim=1)
|
||||
|
||||
draft_position_ids = 1 + torch.arange(
|
||||
self.max_draft_len, device=device, dtype=torch.long
|
||||
).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
draft_position_ids = draft_position_ids + position_ids[:, -1:].expand(
|
||||
-1, self.max_draft_len
|
||||
)
|
||||
|
||||
# draft_position_ids: [batch_size, seq_len + self.max_draft_len]
|
||||
# These position ids will work throughout the drafting loop.
|
||||
draft_position_ids = torch.cat((position_ids, draft_position_ids), dim=1)
|
||||
|
||||
# The number of tokens currently in the draft input ids. Possibly includes padding.
|
||||
curr_num_tokens = seq_len
|
||||
|
||||
# [batch_size]
|
||||
# The number of valid tokens currently in the draft input ids (does not include padding).
|
||||
curr_valid_tokens = num_accepted_tokens.clone()
|
||||
|
||||
batch_indices = torch.arange(batch_size, device=device)
|
||||
|
||||
for _ in range(self.max_draft_len):
|
||||
# Get the input ids, position ids, and hidden states for the current tokens.
# The tensors all share the same token dimension (curr_num_tokens), which is fixed for this iteration.
# These tensors may include padding tokens, but because the draft model is causal,
# we can still extract the draft tokens and hidden states corresponding to the valid tokens.
|
||||
|
||||
input_ids = draft_input_ids[:, :curr_num_tokens]
|
||||
position_ids = draft_position_ids[:, :curr_num_tokens]
|
||||
hidden_states = all_hidden_states[:, :curr_num_tokens, :]
|
||||
|
||||
inputs_embeds = self.apply_draft_embedding(input_ids)
|
||||
draft_output = self.draft_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
draft_output_logits = self.apply_lm_head(draft_output.norm_hidden_state)
|
||||
|
||||
# get the output logits for the latest valid token in each batch
|
||||
# It is at curr_valid_tokens-1 due to 0-indexing.
|
||||
latest_draft_logits = draft_output_logits[batch_indices, curr_valid_tokens - 1, :]
|
||||
|
||||
# draft_output_tokens: [batch_size]
|
||||
draft_output_tokens = self.sample_greedy(latest_draft_logits)
|
||||
|
||||
# if the lm_head outputs tokens from the draft vocab, we need to convert them to tokens
|
||||
# from the target vocab before the next iteration.
|
||||
draft_output_tokens = self.apply_d2t(draft_output_tokens)
|
||||
|
||||
# insert the draft output tokens into the draft input ids.
|
||||
draft_input_ids[batch_indices, curr_valid_tokens] = draft_output_tokens
|
||||
|
||||
# Similarly, we want the hidden state for the latest drafted token in each batch.
|
||||
# This is a draft hidden state for the token that was just created from the latest valid token.
|
||||
|
||||
# [batch_size, seq_len + self.max_draft_len, hidden_size]
|
||||
all_hidden_states[batch_indices, curr_valid_tokens, :] = draft_output.last_hidden_state[
|
||||
batch_indices, curr_valid_tokens - 1, :
|
||||
]
|
||||
|
||||
curr_valid_tokens = curr_valid_tokens + 1
|
||||
curr_num_tokens = curr_num_tokens + 1
|
||||
|
||||
# Return the full draft_input_ids tensor for each batch element.
|
||||
# The valid prefix within each tensor has length:
|
||||
# num_previously_accepted[i] + num_newly_accepted_tokens[i] + max_draft_len
|
||||
# Callers should use this to slice out the valid tokens if needed.
|
||||
new_tokens = [draft_input_ids[i] for i in range(batch_size)]
|
||||
|
||||
return EagleWrapperOutput(
|
||||
new_tokens=new_tokens,
|
||||
new_tokens_lens=num_newly_accepted_tokens,
|
||||
)
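
# Standalone sketch of the drafting-loop position-id bookkeeping used in _forward_prefill_only
# above (toy sizes, no models involved): the drafted slots continue from the last prompt
# position of each row.
import torch

_b, _s, _d = 2, 4, 3
_pos = torch.arange(_s).unsqueeze(0).expand(_b, -1)
_extra = 1 + torch.arange(_d).unsqueeze(0).expand(_b, -1) + _pos[:, -1:].expand(-1, _d)
_draft_pos = torch.cat((_pos, _extra), dim=1)
assert _draft_pos.shape == (_b, _s + _d)
assert _draft_pos[0].tolist() == [0, 1, 2, 3, 4, 5, 6]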
|
||||
|
||||
@@ -22,14 +22,14 @@ Eagle drafter implementations.
|
||||
"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, Type
|
||||
from typing import Dict
|
||||
|
||||
import torch.nn as nn
|
||||
from accelerate import init_empty_weights
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .custom.modeling_eagle import Eagle3DrafterForCausalLM
|
||||
from ..utils.logger import ad_logger
|
||||
from .custom.modeling_eagle import Eagle3DrafterForCausalLM, EagleConfig
|
||||
from .factory import ModelFactoryRegistry
|
||||
from .hf import AutoModelForCausalLMFactory
|
||||
|
||||
@@ -47,23 +47,28 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory):
|
||||
(e.g., "llama") along with Eagle-specific fields like draft_vocab_size.
|
||||
"""
|
||||
|
||||
# Map config model_type -> Eagle drafter model class
|
||||
_drafter_model_mapping: Dict[str, Type[PreTrainedModel]] = {
|
||||
_drafter_classes: Dict[str, type] = {
|
||||
"llama": Eagle3DrafterForCausalLM,
|
||||
}
|
||||
|
||||
def _build_model(self, device: DeviceLikeType) -> nn.Module:
|
||||
model_config, unused_kwargs = self._get_model_config()
|
||||
|
||||
# Select the appropriate drafter class based on the base model type
|
||||
match model_config.model_type:
|
||||
case "llama":
|
||||
drafter_cls = self._drafter_model_mapping["llama"]
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unsupported model_type '{model_config.model_type}' for Eagle drafter. "
|
||||
f"Supported types: {list(self._drafter_model_mapping.keys())}"
|
||||
)
|
||||
# Select the appropriate drafter class and config based on the base model type
|
||||
model_type = model_config.model_type
|
||||
if model_type not in self._drafter_classes:
|
||||
raise ValueError(
|
||||
f"Unsupported model_type '{model_type}' for Eagle drafter. "
|
||||
f"Supported types: {list(self._drafter_classes.keys())}"
|
||||
)
|
||||
drafter_cls = self._drafter_classes[model_type]
|
||||
ad_logger.info(
|
||||
f"EagleDrafterFactory: model_type='{model_type}' -> drafter_cls={drafter_cls.__name__}"
|
||||
)
|
||||
|
||||
# Convert base config to EagleConfig, preserving existing values
|
||||
# and applying model-specific defaults based on model_type
|
||||
model_config = EagleConfig(model_config, model_type)
|
||||
|
||||
# Build the model (same pattern as parent's _build_model)
|
||||
with (init_empty_weights if device == "meta" else nullcontext)():
|
||||
@@ -83,7 +88,7 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory):
|
||||
|
||||
return model
|
||||
|
||||
def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
|
||||
def build_and_load_model(self, _device: DeviceLikeType) -> nn.Module:
|
||||
raise NotImplementedError(
|
||||
"EagleDrafterFactory does not support build_and_load_model(). "
|
||||
"Use build_model() + load_or_random_init() instead."
|
||||
|
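# Usage sketch of the intended two-step flow (grounded in the error message above and in the
# factory calls made later in the tests; the checkpoint path below is hypothetical):
#
#   factory = EagleDrafterFactory(model="/path/to/eagle3-draft", skip_loading_weights=False)
#   drafter = factory.build_model("meta")          # build with empty weights on the meta device
#   factory.load_or_random_init(drafter, "cuda")   # then materialize and load the checkpoint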
||||
@@ -15,15 +15,27 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from build_and_run_ad import ExperimentConfig, main
|
||||
from defs.conftest import llm_models_root
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch.auto_deploy.llm import LLM
|
||||
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
|
||||
EagleWrapper,
|
||||
EagleWrapperConfig,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
|
||||
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig
|
||||
|
||||
@@ -83,8 +95,6 @@ def run_with_autodeploy(model, speculative_config, batch_size):
|
||||
# Select prompts based on batch size
|
||||
selected_prompts = prompts[:batch_size]
|
||||
|
||||
spec_config = speculative_config
|
||||
|
||||
# Configure KV cache
|
||||
kv_cache_config = KvCacheConfig(
|
||||
free_gpu_memory_fraction=0.01,
|
||||
@@ -94,7 +104,7 @@ def run_with_autodeploy(model, speculative_config, batch_size):
|
||||
llm_args = {
|
||||
"model": model,
|
||||
"skip_loading_weights": False,
|
||||
"speculative_config": spec_config,
|
||||
"speculative_config": speculative_config,
|
||||
"runtime": "trtllm",
|
||||
"world_size": 1,
|
||||
"kv_cache_config": kv_cache_config,
|
||||
@@ -388,30 +398,20 @@ def test_eagle_model_with_weights():
|
||||
print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}")
|
||||
print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}")
|
||||
|
||||
if missing_keys:
|
||||
print("\nMissing keys (model expects but checkpoint doesn't have):")
|
||||
for key in sorted(missing_keys):
|
||||
if "embed_tokens" in key:
|
||||
print(f" - {key} (expected: shared from target model)")
|
||||
elif "rotary_emb" in key:
|
||||
print(f" - {key} (expected: computed at runtime)")
|
||||
else:
|
||||
print(f" - {key}")
|
||||
|
||||
if unexpected_keys:
|
||||
print("\nUnexpected keys (in checkpoint but model doesn't expect):")
|
||||
for key in sorted(unexpected_keys):
|
||||
if "t2d" in key:
|
||||
print(f" - {key} (expected: not used in Eagle3)")
|
||||
print(f" - {key} (expected: not used in Eagle3 for Llama3.1-8B-Instruct)")
|
||||
else:
|
||||
print(f" - {key}")
|
||||
|
||||
if loaded_keys:
|
||||
print(f"\nLoaded keys ({len(loaded_keys)} total):")
|
||||
for key in sorted(loaded_keys)[:10]:
|
||||
for key in sorted(loaded_keys)[:20]:
|
||||
print(f" - {key}")
|
||||
if len(loaded_keys) > 10:
|
||||
print(f" ... and {len(loaded_keys) - 10} more")
|
||||
if len(loaded_keys) > 20:
|
||||
print(f" ... and {len(loaded_keys) - 20} more")
|
||||
|
||||
print("--- End Weight Analysis ---\n")
|
||||
|
||||
@@ -419,15 +419,10 @@ def test_eagle_model_with_weights():
|
||||
# These are the keys we expect based on Eagle3 architecture:
|
||||
# - embed_tokens: shared from target model (not in Eagle checkpoint)
|
||||
# - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead)
|
||||
expected_missing_keys = {"model.embed_tokens.weight"}
|
||||
expected_unexpected_keys = {"model.t2d"}
|
||||
|
||||
assert missing_keys == expected_missing_keys, (
|
||||
f"Unexpected missing keys.\n"
|
||||
f"Expected: {expected_missing_keys}\n"
|
||||
f"Got: {missing_keys}\n"
|
||||
f"Extra missing: {missing_keys - expected_missing_keys}\n"
|
||||
f"Not missing (but expected): {expected_missing_keys - missing_keys}"
|
||||
assert len(missing_keys) == 0, (
|
||||
f"Expect all keys to be loaded.\nKeys that are missing: {missing_keys}\n"
|
||||
)
|
||||
|
||||
assert unexpected_keys == expected_unexpected_keys, (
|
||||
@@ -448,3 +443,777 @@ def test_eagle_model_with_weights():
|
||||
print("Weights loaded successfully via factory interface!")
|
||||
|
||||
model.eval()
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Set up to test the prefill-only version of the EagleWrapper model in test_eagle_wrapper_forward().
# This helps us guarantee that the EagleWrapper model is working correctly before it enters AutoDeploy.
# The test does not rely on any TRTLLM logic.
|
||||
###############################################################################
|
||||
class PrefillOnlyEagleResourceManager:
|
||||
"""Simple resource manager for Eagle speculative decoding (prefill-only variant).
|
||||
|
||||
Stores hidden states for use by draft loop in EagleWrapper.forward().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_capture_layers: int,
|
||||
max_batch_size: int,
|
||||
max_seq_len: int,
|
||||
max_draft_len: int,
|
||||
target_dtype: torch.dtype,
|
||||
):
|
||||
# Buffer for hidden states from target model: [max_tokens, hidden_size * num_capture_layers]
|
||||
# Uses flattened 2D format to match ADHiddenStateManager
|
||||
self.hidden_states = torch.empty(
|
||||
max_batch_size * (max_seq_len + max_draft_len),
|
||||
hidden_size * num_capture_layers,
|
||||
device="cuda",
|
||||
dtype=target_dtype,
|
||||
)
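
# Size sketch with the toy numbers used later in this test (illustrative): the flattened buffer
# holds one row per token slot across the whole batch, including the extra draft slots.
_toy_rows = 4 * (1024 + 3)        # max_batch_size * (max_seq_len + max_draft_len)
_toy_cols = 4096 * 3              # hidden_size * num_capture_layers
assert (_toy_rows, _toy_cols) == (4108, 12288)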
|
||||
|
||||
|
||||
class LlamaModelWithCapture(LlamaModel):
|
||||
"""LlamaModel that captures un-normalized hidden states from specified layers.
|
||||
|
||||
Overrides the base model's forward method to capture hidden states from the specified layers.
|
||||
Base model's forward method is otherwise copied from LlamaModel in HuggingFace.
|
||||
Takes PrefillOnlyEagleResourceManager as an argument to store captured hidden states.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
layers_to_capture: Optional[Set[int]] = None,
|
||||
resource_manager: Optional[PrefillOnlyEagleResourceManager] = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
# layers_to_capture: set of layer indices (0-indexed) to capture
|
||||
# If None, capture all layers
|
||||
if layers_to_capture is None:
|
||||
self.layers_to_capture = set(range(config.num_hidden_layers))
|
||||
else:
|
||||
self.layers_to_capture = set(layers_to_capture)
|
||||
|
||||
self.resource_manager = resource_manager
|
||||
|
||||
# Validate layer indices
|
||||
for idx in self.layers_to_capture:
|
||||
if idx < 0 or idx >= config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"Layer index {idx} out of range. "
|
||||
f"Model has {config.num_hidden_layers} layers (0 to {config.num_hidden_layers - 1})"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutputWithPast:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
# prefill only - no past key values.
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# Buffer to collect captured hidden states
|
||||
captured_hidden_states = []
|
||||
|
||||
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Capture this layer's output if it's in our list
|
||||
if layer_idx in self.layers_to_capture:
|
||||
captured_hidden_states.append(hidden_states)
|
||||
|
||||
# Apply final normalization for last_hidden_state
|
||||
last_hidden_state = self.norm(hidden_states)
|
||||
|
||||
# Store captured hidden states in resource manager if available
|
||||
# Resource manager uses 2D flattened format: [max_tokens, hidden_size * num_capture_layers]
|
||||
if self.resource_manager is not None and captured_hidden_states:
|
||||
concatenated = torch.cat(captured_hidden_states, dim=-1)
|
||||
batch_size, seq_len, total_hidden_size = concatenated.shape
|
||||
assert self.resource_manager.hidden_states.shape[-1] == total_hidden_size, (
|
||||
f"Resource manager buffer last dim {self.resource_manager.hidden_states.shape[-1]} "
|
||||
f"!= concatenated hidden states last dim {total_hidden_size}"
|
||||
)
|
||||
# Flatten to [batch_size * seq_len, total_hidden_size] for 2D format
|
||||
flattened = concatenated.view(batch_size * seq_len, total_hidden_size)
|
||||
self.resource_manager.hidden_states[: (batch_size * seq_len), :].copy_(flattened)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=last_hidden_state,
|
||||
hidden_states=tuple(captured_hidden_states) if captured_hidden_states else None,
|
||||
)
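
# Standalone sketch of the flattening convention used above (toy sizes): per-layer captures
# [batch, seq, hidden] are concatenated on the last dim and stored as 2D rows
# [batch * seq, hidden * num_capture_layers].
import torch

_caps = [torch.randn(2, 5, 8) for _ in range(3)]
_flat = torch.cat(_caps, dim=-1).view(2 * 5, 8 * 3)
assert _flat.shape == (10, 24)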
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlamaForCausalLMOutput(ModelOutput):
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
class LlamaForCausalLMWithCapture(nn.Module):
|
||||
"""Wrapper combining LlamaModelWithCapture with lm_head for EagleWrapper testing.
|
||||
|
||||
EagleWrapper.forward() expects target_model(input_ids, position_ids) to return logits.
|
||||
This class wraps LlamaModelWithCapture (which captures hidden states to resource manager)
|
||||
and adds the lm_head to produce logits.
|
||||
"""
|
||||
|
||||
def __init__(self, base_model, capture_model):
|
||||
super().__init__()
|
||||
self.model = capture_model # LlamaModelWithCapture with resource_manager
|
||||
self.lm_head = base_model.lm_head
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
outputs = self.model(
|
||||
input_ids=input_ids, inputs_embeds=inputs_embeds, position_ids=position_ids, **kwargs
|
||||
)
|
||||
logits = self.lm_head(outputs.last_hidden_state)
|
||||
return LlamaForCausalLMOutput(
|
||||
logits=logits,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head  # the lm_head lives on this wrapper, not on the inner LlamaModelWithCapture
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_name: str,
|
||||
resource_manager,
|
||||
capture_layers,
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
"""Load a base model and create a LlamaForCausalLMWithCapture with shared weights."""
|
||||
print(f"Loading {model_name}...")
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=dtype,
|
||||
device_map={"": 0},
|
||||
)
|
||||
base_model.eval()
|
||||
|
||||
# Create LlamaModelWithCapture that shares weights with the base model
|
||||
original_llama_model = base_model.model
|
||||
|
||||
capture_model = LlamaModelWithCapture.__new__(LlamaModelWithCapture)
|
||||
nn.Module.__init__(capture_model)
|
||||
|
||||
capture_model.config = original_llama_model.config
|
||||
capture_model.layers_to_capture = capture_layers
|
||||
capture_model.resource_manager = resource_manager
|
||||
|
||||
# Share all modules (no weight copying)
|
||||
capture_model.embed_tokens = original_llama_model.embed_tokens
|
||||
capture_model.layers = original_llama_model.layers
|
||||
capture_model.norm = original_llama_model.norm
|
||||
capture_model.rotary_emb = original_llama_model.rotary_emb
|
||||
capture_model.gradient_checkpointing = original_llama_model.gradient_checkpointing
|
||||
|
||||
return cls(base_model, capture_model)
|
||||
|
||||
|
||||
def build_eagle_wrapper(
|
||||
base_model_path: str,
|
||||
eagle_model_path: str,
|
||||
resource_manager: PrefillOnlyEagleResourceManager,
|
||||
capture_layers: Set[int],
|
||||
max_seq_len: int,
|
||||
max_draft_len: int,
|
||||
target_dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> tuple[EagleWrapper, nn.Module]:
|
||||
"""Build an EagleWrapper model for testing.
|
||||
|
||||
This function encapsulates the model building logic using manual model building.
|
||||
|
||||
Returns:
|
||||
A tuple of (eagle_wrapper, target_model) where:
|
||||
- eagle_wrapper: The EagleWrapper model ready for inference.
|
||||
- target_model: The target model (for verification steps).
|
||||
"""
|
||||
# Build EagleWrapper manually.
|
||||
print("\n" + "-" * 40)
|
||||
print("Building EagleWrapper")
|
||||
print("-" * 40)
|
||||
|
||||
# Create target model with capture
|
||||
target_model = LlamaForCausalLMWithCapture.from_pretrained(
|
||||
base_model_path, resource_manager, capture_layers, target_dtype
|
||||
)
|
||||
print("✓ Created target model with capture")
|
||||
|
||||
# Create draft model using EagleDrafterFactory (mimics production pipeline)
|
||||
# This ensures weights are loaded correctly via the same path as AutoDeploy
|
||||
print("\nCreating draft model via EagleDrafterFactory...")
|
||||
draft_factory = EagleDrafterFactory(
|
||||
model=eagle_model_path,
|
||||
skip_loading_weights=False,
|
||||
)
|
||||
|
||||
# Build model on meta device first, then load weights
|
||||
draft_model = draft_factory.build_model("meta")
|
||||
print(f" Model type: {type(draft_model).__name__}")
|
||||
|
||||
# Load weights via factory
|
||||
print(" Loading weights via factory.load_or_random_init()...")
|
||||
draft_factory.load_or_random_init(draft_model, device)
|
||||
draft_model.eval()
|
||||
|
||||
# Create EagleWrapper config
|
||||
wrapper_config = EagleWrapperConfig(
|
||||
max_draft_len=max_draft_len,
|
||||
load_embedding_from_target=draft_model.load_embedding_from_target,
|
||||
load_lm_head_from_target=draft_model.load_lm_head_from_target,
|
||||
)
|
||||
|
||||
# Build EagleWrapper (this also loads weights from target into draft model where necessary)
|
||||
eagle_wrapper = EagleWrapper(
|
||||
config=wrapper_config,
|
||||
target_model=target_model,
|
||||
draft_model=draft_model,
|
||||
resource_manager=resource_manager,
|
||||
)
|
||||
eagle_wrapper.eval()
|
||||
print("✓ Built EagleWrapper")
|
||||
|
||||
return eagle_wrapper, target_model
|
||||
|
||||
|
||||
def generate_target_outputs(
|
||||
target_model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
num_iterations: int,
|
||||
) -> torch.Tensor:
|
||||
"""Generate tokens from target model using greedy sampling.
|
||||
|
||||
Runs target_model.forward() in a loop, taking the last logit from each output,
|
||||
applying greedy sampling with torch.argmax, and appending to input_ids.
|
||||
|
||||
Args:
|
||||
target_model: Model that returns logits from forward(input_ids, position_ids).
|
||||
input_ids: Initial input token ids of shape [batch_size, seq_len].
|
||||
num_iterations: Number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
output_ids: Tensor of shape [batch_size, seq_len + num_iterations] containing
|
||||
the original input_ids plus the generated tokens.
|
||||
"""
|
||||
device = input_ids.device
|
||||
init_seq_len = input_ids.shape[1]
|
||||
print(f"Initial sequence length: {init_seq_len}")
|
||||
current_ids = input_ids.clone()
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(num_iterations):
|
||||
# Generate position_ids from current sequence length
|
||||
seq_len = current_ids.shape[1]
|
||||
position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
|
||||
position_ids = position_ids.expand(current_ids.shape[0], -1)
|
||||
|
||||
# Forward pass
|
||||
logits = target_model(current_ids, position_ids=position_ids).logits
|
||||
|
||||
# Take the last logit and apply greedy sampling
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
next_token = torch.argmax(last_logits, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# Append to current_ids
|
||||
current_ids = torch.cat([current_ids, next_token], dim=1)
|
||||
|
||||
return current_ids
|
||||
|
||||
|
||||
def print_token_analysis(
|
||||
input_ids: torch.Tensor,
|
||||
num_previously_accepted: torch.Tensor,
|
||||
target_output_ids: torch.Tensor,
|
||||
tokenizer,
|
||||
) -> None:
|
||||
"""Print debug analysis of accepted vs speculative tokens for each batch.
|
||||
|
||||
Args:
|
||||
input_ids: Current input token ids of shape [batch_size, seq_len].
|
||||
num_previously_accepted: Number of accepted tokens per batch [batch_size].
|
||||
target_output_ids: Reference output from target model [batch_size, total_seq_len].
|
||||
tokenizer: Tokenizer for decoding tokens to text.
|
||||
"""
|
||||
batch_size = input_ids.shape[0]
|
||||
print("\n --- Token Analysis (per batch) ---")
|
||||
|
||||
for i in range(batch_size):
|
||||
prev_accepted_i = num_previously_accepted[i].item()
|
||||
|
||||
# Accepted tokens (before speculation): input_ids[i, :num_previously_accepted[i]]
|
||||
accepted_tokens = input_ids[i, :prev_accepted_i]
|
||||
# Speculative tokens: input_ids[i, num_previously_accepted[i]:]
|
||||
speculative_tokens = input_ids[i, prev_accepted_i:]
|
||||
|
||||
# Target model's expected token at this position
|
||||
if prev_accepted_i < target_output_ids.shape[1]:
|
||||
target_token_at_pos = target_output_ids[i, prev_accepted_i]
|
||||
else:
|
||||
target_token_at_pos = None
|
||||
|
||||
print(f"\n Batch {i}:")
|
||||
print(f" num_previously_accepted: {prev_accepted_i}")
|
||||
print(
|
||||
f" Accepted tokens ({accepted_tokens.shape[0]} tokens): {accepted_tokens.tolist()}"
|
||||
)
|
||||
accepted_text = tokenizer.decode(accepted_tokens, skip_special_tokens=True)
|
||||
print(f' Accepted text: "{accepted_text}"')
|
||||
print(
|
||||
f" Speculative tokens ({speculative_tokens.shape[0]} tokens): {speculative_tokens.tolist()}"
|
||||
)
|
||||
if speculative_tokens.shape[0] > 0:
|
||||
spec_text = tokenizer.decode(speculative_tokens, skip_special_tokens=False)
|
||||
print(f' Speculative text: "{spec_text}"')
|
||||
if target_token_at_pos is not None:
|
||||
target_tok_id = target_token_at_pos.item()
|
||||
target_tok_str = tokenizer.decode([target_tok_id])
|
||||
print(
|
||||
f' Target model\'s next token at pos {prev_accepted_i}: {target_tok_id} ("{target_tok_str}")'
|
||||
)
|
||||
|
||||
|
||||
def manual_sample_and_verify(
|
||||
next_target_inputs: list,
|
||||
num_accepted_tokens: torch.Tensor,
|
||||
target_model: nn.Module,
|
||||
eagle_wrapper: nn.Module,
|
||||
max_draft_len: int,
|
||||
device: torch.device,
|
||||
) -> list:
|
||||
"""Manually verify speculative tokens using sample_and_verify.
|
||||
|
||||
This is used for batch_size > 1 where truncation prevents speculative tokens
|
||||
from being fed back, so we verify them manually before truncation.
|
||||
|
||||
Args:
|
||||
next_target_inputs: List of tensors, one per batch element.
|
||||
num_accepted_tokens: Number of tokens accepted so far per batch [batch_size].
|
||||
target_model: The target model for running forward pass.
|
||||
eagle_wrapper: The EagleWrapper containing sample_and_verify.
|
||||
max_draft_len: Maximum draft length (for capping counts).
|
||||
device: Device to run on.
|
||||
|
||||
Returns:
|
||||
List of (num_accepted, num_speculative) tuples for each batch element.
|
||||
"""
|
||||
batch_size = len(next_target_inputs)
|
||||
|
||||
# Due to our truncation trick, all sequences should have the same length
|
||||
seq_lens = [seq.shape[0] for seq in next_target_inputs]
|
||||
assert all(slen == seq_lens[0] for slen in seq_lens), (
|
||||
f"All sequences should have same length due to truncation, got {seq_lens}"
|
||||
)
|
||||
verify_seq_len = seq_lens[0]
|
||||
|
||||
# Stack into batched tensor
|
||||
stacked_inputs = torch.stack(next_target_inputs, dim=0) # [batch_size, seq_len]
|
||||
|
||||
# Run target model forward to get logits
|
||||
verify_position_ids = (
|
||||
torch.arange(verify_seq_len, device=device, dtype=torch.long)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size, -1)
|
||||
)
|
||||
with torch.no_grad():
|
||||
verify_target_logits = target_model(stacked_inputs, position_ids=verify_position_ids).logits
|
||||
|
||||
# new_num_previously_accepted = num_accepted_tokens + 1
|
||||
# This represents the tokens accepted after target model's output from this iteration
|
||||
new_num_previously_accepted = num_accepted_tokens + 1
|
||||
|
||||
# Call sample_and_verify to get acceptance counts
|
||||
_, verify_newly_accepted, _, _ = eagle_wrapper.sample_and_verify(
|
||||
stacked_inputs, verify_target_logits, new_num_previously_accepted
|
||||
)
|
||||
|
||||
# Build results list
|
||||
results = []
|
||||
for i in range(batch_size):
|
||||
num_accepted_i = min(verify_newly_accepted[i].item(), max_draft_len)
|
||||
num_speculative = next_target_inputs[i].shape[0] - new_num_previously_accepted[i].item()
|
||||
results.append((num_accepted_i, num_speculative))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted):
|
||||
"""Verify the output structure and values from EagleWrapper forward pass.
|
||||
|
||||
Args:
|
||||
output: The output from EagleWrapper forward pass.
|
||||
tokenizer: The tokenizer for decoding tokens.
|
||||
batch_size: The batch size.
|
||||
num_previously_accepted: Tensor of previously accepted token counts.
|
||||
"""
|
||||
# Verify output structure
|
||||
print("\nOutput verification:")
|
||||
assert output is not None, "Output should not be None"
|
||||
assert hasattr(output, "new_tokens"), "Output should have new_tokens"
|
||||
assert hasattr(output, "new_tokens_lens"), "Output should have new_tokens_lens"
|
||||
|
||||
print(f" new_tokens: {type(output.new_tokens)} with {len(output.new_tokens)} items")
|
||||
for i, tokens in enumerate(output.new_tokens):
|
||||
new_tokens_text = tokenizer.decode(tokens, skip_special_tokens=True)
|
||||
print(f" batch {i}: shape {tokens.shape}, tokens: {tokens.tolist()}")
|
||||
print(f' batch {i}: decoded: "{new_tokens_text}"')
|
||||
|
||||
# Compute num_accepted_tokens from new_tokens_lens + num_previously_accepted
|
||||
num_accepted_tokens = num_previously_accepted + output.new_tokens_lens
|
||||
|
||||
print(f" new_tokens_lens: {output.new_tokens_lens}")
|
||||
print(f" num_accepted_tokens (computed): {num_accepted_tokens}")
|
||||
|
||||
# Verify new_tokens_lens is within expected bounds
|
||||
assert output.new_tokens_lens.shape == (batch_size,), (
|
||||
f"new_tokens_lens shape should be ({batch_size},), got {output.new_tokens_lens.shape}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_eagle_wrapper_forward(batch_size: int):
|
||||
"""Test EagleWrapper forward pass with target and draft models.
|
||||
|
||||
This test validates the full speculative decoding loop:
|
||||
1. Target model processes input and captures hidden states
|
||||
2. Draft model generates speculative tokens
|
||||
3. EagleWrapper orchestrates verification and drafting
|
||||
|
||||
For batch size 1, we call EagleWrapper forward in the expected way. Each iteration generates a "golden token"
|
||||
(target output) and draft tokens. We input all of them to the wrapper model,
|
||||
which verifies the draft tokens against the target output. It then outputs the accepted tokens
|
||||
and newly generated draft tokens, along with numbers of accepted tokens, and the process repeats.
|
||||
|
||||
For batch size > 1, we need to work around the fact that, as we run the loop described above, the sequence lengths
|
||||
in the batch will get out of sync. So instead, we do not provide validated draft tokens as input in each iteration -
|
||||
we just input the first accepted token from the previous iteration
|
||||
(which we know was generated by the target model), which keeps the batches in sync.
|
||||
|
||||
To verify that the output draft tokens are reasonable, we run a manual target model verification step
|
||||
after each iteration. We record how many of the output draft tokens were accepted.
|
||||
|
||||
In the end, we test that the acceptance ratio of the draft tokens generated by the EagleWrapper is reasonable.
|
||||
|
||||
Args:
|
||||
batch_size: Number of prompts to process in parallel.
|
||||
"""
|
||||
    print("\n" + "=" * 80)
    print("Test: EagleWrapper forward pass")
    print("=" * 80)

    # Set random seeds for reproducibility
    torch.manual_seed(42)

    # Get model paths using integration test conventions
    base_model_path, _, eagle_model_path = get_model_paths()
    eagle_path = Path(eagle_model_path)

    if not eagle_path.exists():
        pytest.skip("Eagle model not found (model missing)")

    # Configuration
    capture_layers = {1, 15, 28}  # Layers to capture for Eagle3
    num_capture_layers = len(capture_layers)
    hidden_size = 4096  # Llama 3.1-8B hidden size
    dtype = torch.bfloat16
    device = torch.device("cuda")

    # Test dimensions
    max_batch_size = 4
    max_seq_len = 1024
    max_draft_len = 3

    # Tokenize the test prompts
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    # Llama uses left padding for batch inference
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    if batch_size == 1:
        input_ids = tokenizer.encode(prompts[0], return_tensors="pt").to(device)
    else:
        tokenized = tokenizer(
            prompts[:batch_size],
            return_tensors="pt",
            padding=True,
        )
        input_ids = tokenized.input_ids.to(device)

    print(f"input_ids: {input_ids}")
    seq_len = input_ids.shape[1]
    init_seq_len = seq_len  # Store initial sequence length for final comparison

    print("\nTest configuration:")
    print(f" target_model: {base_model_path}")
    print(f" eagle_model: {eagle_path}")
    print(f" batch_size: {batch_size}, seq_len: {seq_len}")
    print(f" max_draft_len: {max_draft_len}")
    print(f" capture_layers: {capture_layers}")
    print(f" prompts: {prompts[:batch_size]}")
    print(f" input_ids: {input_ids}")

    # Create resource manager
    resource_manager = PrefillOnlyEagleResourceManager(
        hidden_size=hidden_size,
        num_capture_layers=num_capture_layers,
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        max_draft_len=max_draft_len,
        target_dtype=dtype,
    )
    print("\n✓ Created resource manager")
    print(f" target_hidden_states shape: {resource_manager.hidden_states.shape}")

    # Build eagle_wrapper and target_model using the refactored function
    eagle_wrapper, target_model = build_eagle_wrapper(
        base_model_path=base_model_path,
        eagle_model_path=str(eagle_path),
        resource_manager=resource_manager,
        capture_layers=capture_layers,
        max_seq_len=max_seq_len,
        max_draft_len=max_draft_len,
        target_dtype=dtype,
        device=device,
    )

    # Create test inputs (input_ids already created from tokenizer above)
    position_ids = (
        torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
    )
    # Set num_previously_accepted to the length of input_ids (all context tokens are accepted)
    # Shape should be [batch_size] - a 1D tensor with one value per batch
    num_previously_accepted = torch.full((batch_size,), seq_len, device=device, dtype=torch.long)
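    # For example, with batch_size == 2 and a padded prompt length of 17 this would be
    # tensor([17, 17]); the length 17 is illustrative, the real value is whatever the tokenizer produced.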

    print("\nTest inputs:")
    print(f" input_ids shape: {input_ids.shape}")
    print(f" input_ids: {input_ids}")
    print(f" position_ids shape: {position_ids.shape}")
    print(f" num_previously_accepted: {num_previously_accepted}")

    # Generate target model outputs with greedy sampling
    print("\nGenerating target model outputs with greedy sampling (for verification)...")
    target_output_ids = generate_target_outputs(target_model, input_ids, num_iterations=100)
    print(f" target_output_ids shape: {target_output_ids.shape}")
    print(f" target_output_ids: {target_output_ids}")

    # Decode to text as sanity check
    generated_text = tokenizer.decode(target_output_ids[0], skip_special_tokens=True)
    print(f"\n Target model greedy generation decoded text:\n {generated_text}")

    print("\n✓ EagleWrapper forward pass completed successfully!")
    print("✓ Output structure verified")
    print("✓ new_tokens_lens within expected bounds")
    print("✓ Target model greedy generation completed")

    print("\n================================================")

    num_iterations = 70

    # Dictionary to track distribution of new_tokens_lens
    # keys: 0 to max_draft_len
    # newly_accepted_counts[i]: number of times the number of accepted draft tokens was i
    newly_accepted_counts = {i: 0 for i in range(max_draft_len + 1)}
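    # Illustrative (hypothetical) final distribution: {0: 5, 1: 14, 2: 30, 3: 20} would mean that in
    # 30 iterations exactly 2 of the max_draft_len == 3 draft tokens were accepted, and so on.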

    for iteration in range(num_iterations):
        print(f"\n{'=' * 40}")
        print(f"EagleWrapper forward pass - Iteration {iteration + 1}/{num_iterations}")
        print(f"{'=' * 40}")

        seq_len = input_ids.shape[1]

        # Debug: Print speculative tokens, accepted tokens, and target comparison
        print_token_analysis(input_ids, num_previously_accepted, target_output_ids, tokenizer)

        kwargs = {
            "num_previously_accepted": num_previously_accepted,
        }
        with torch.no_grad():
            output = eagle_wrapper(
                input_ids=input_ids,
                position_ids=position_ids,
                **kwargs,
            )

        verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted)

        # Prepare next_target_inputs
        # output.new_tokens[i] contains the full draft_input_ids tensor, but the valid prefix
        # has length num_accepted_tokens[i] + max_draft_len. We slice to get only valid tokens.
        # We then prepend the first token from the previous iteration's input_ids.
        # This prepending is only needed for prefill-only mode, since in the cached case, the first token
        # will always be in the KV cache.
        # Compute num_accepted_tokens from num_previously_accepted + new_tokens_lens
        num_accepted_tokens = num_previously_accepted + output.new_tokens_lens
        valid_prefix_len = num_accepted_tokens + max_draft_len
        next_target_inputs = [
            torch.cat(
                (input_ids[i, 0].unsqueeze(0), output.new_tokens[i][: valid_prefix_len[i]]),
                dim=0,
            )
            for i in range(batch_size)
        ]
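        # Illustrative (hypothetical) numbers: with num_accepted_tokens[i] == 12 and max_draft_len == 3,
        # valid_prefix_len[i] == 15, so we keep output.new_tokens[i][:15] and prepend input_ids[i, 0],
        # giving a 16-token input for the next iteration.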

        # Track distribution of newly accepted tokens by reading new_tokens_lens from the output.
        # For batch size = 1, we are inputting draft tokens to the wrapper model, so new_tokens_lens
        # gives the number of accepted tokens from drafts in the previous iteration.
        if batch_size == 1:
            for val in output.new_tokens_lens.tolist():
                newly_accepted_counts[val] += 1
            print(f" newly_accepted_counts so far: {newly_accepted_counts}")

        # For batch_size > 1, we use manual target model verification below instead to check which of the draft tokens
        # generated in *this* iteration would be accepted by the target model.
        else:
            # For batch_size > 1, verify acceptance using sample_and_verify()
            # before truncation (since truncation prevents speculative tokens from being fed back)
            verify_results = manual_sample_and_verify(
                next_target_inputs,
                num_accepted_tokens,
                target_model,
                eagle_wrapper,
                max_draft_len,
                device,
            )

            # Update newly_accepted_counts map
            for i, (num_accepted_i, num_speculative) in enumerate(verify_results):
                newly_accepted_counts[num_accepted_i] += 1
                print(
                    f" [Batch {i}] sample_and_verify: {num_accepted_i}/{num_speculative} speculative accepted"
                )
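            # Illustrative (hypothetical) result: verify_results == [(2, 3), (1, 3)] would mean batch 0
            # had 2 of its 3 speculative tokens accepted and batch 1 had 1 of 3 accepted.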

            # Truncate to keep shapes consistent across batches in each iteration.
            # We know that the first token that is generated in this iteration is accepted, so it is "safe".
            # All speculative tokens are truncated regardless of whether they are accepted or not.
            # This is a hack to prevent the sequence lengths from getting out of sync across batches in each iteration
            # without needing to change the padding every iteration.
            truncate_len = input_ids.shape[1] + 1
            next_target_inputs = [seq[:truncate_len] for seq in next_target_inputs]
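            # For example, if input_ids currently has length 20, each sequence is cut back to its first
            # 21 tokens: the previous input plus the one new token known to be accepted.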

        next_target_inputs = torch.stack(next_target_inputs, dim=0)

        print(f" next_target_inputs: {next_target_inputs}")
        print(f" next_target_inputs.shape: {next_target_inputs.shape}")

        # Update for next iteration
        input_ids = next_target_inputs
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
        position_ids = position_ids.expand(batch_size, -1)

        if batch_size > 1:
            # For multi-batch: increment by 1 (we truncated, so just advance by one token)
            num_previously_accepted = num_previously_accepted + 1
        else:
            # For single batch: accept the tokens accepted in the previous iteration, plus one
            # for the output token that was generated by the target.
            num_previously_accepted = num_accepted_tokens + 1

    print(f"\n{'=' * 40}")
    print(f"Loop completed: {num_iterations} iterations")
    print("Newly accepted tokens distribution:")
    for k, v in newly_accepted_counts.items():
        print(f" {k}: {v}")

    # Calculate acceptance ratio
    # For batch_size == 1: uses new_tokens_lens from eagle wrapper
    # For batch_size > 1: uses manual verification against target model (since truncation
    # prevents speculative tokens from being fed back)
    total_accepted = sum(k * v for k, v in newly_accepted_counts.items())
    # First iteration has no tokens to newly accept, subsequent iterations have max_draft_len potential

    num_iterations_with_drafts = num_iterations - 1 if batch_size == 1 else num_iterations
    total_potential = max_draft_len * num_iterations_with_drafts * batch_size
    acceptance_ratio = total_accepted / total_potential if total_potential > 0 else 0.0
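    # Worked example for batch_size == 1: with num_iterations == 70, 69 iterations carry drafts, so
    # total_potential == 3 * 69 * 1 == 207; if, say, 120 drafts were accepted overall, the ratio would
    # be 120 / 207 ≈ 0.58 (the 120 is purely illustrative).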
print(f"\nAcceptance ratio: {total_accepted}/{total_potential} = {acceptance_ratio:.3f}")
|
||||
if batch_size > 1:
|
||||
print(" (batch_size > 1: measured via manual target model verification)")
|
||||
assert acceptance_ratio > 0.1, (
|
||||
f"Acceptance ratio {acceptance_ratio:.3f} is too low (expected > 0.1)"
|
||||
)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("FINAL OUTPUT COMPARISON")
|
||||
print("=" * 80)
|
||||
for i in range(batch_size):
|
||||
print(f"\n{'─' * 40}")
|
||||
print(f"BATCH {i}")
|
||||
print(f"{'─' * 40}")
|
||||
print(f"\n[Target Model Output] ({target_output_ids[i].shape[0]} tokens):")
|
||||
print(f" Tokens: {target_output_ids[i].tolist()}")
|
||||
print(f' Text: "{tokenizer.decode(target_output_ids[i], skip_special_tokens=True)}"')
|
||||
print(f"\n[Eagle Wrapper Output] ({input_ids[i].shape[0]} tokens):")
|
||||
print(f" Tokens: {input_ids[i].tolist()}")
|
||||
print(f' Text: "{tokenizer.decode(input_ids[i], skip_special_tokens=True)}"')
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
# Verify that the first 10 generated tokens match between target model and eagle wrapper
|
||||
# They seem to diverge after awhile but are semantically the same.
|
||||
# Note that even running the target model in decode vs prefill mode, the outputs seem to diverge similarly,
|
||||
# so this is not worrisome. This test provides a check that they are "similar enough" to each other.
|
||||
num_tokens_to_check = 10
|
||||
print(f"\nVerifying first {num_tokens_to_check} generated tokens match...")
|
||||
for i in range(batch_size):
|
||||
target_generated = target_output_ids[i, init_seq_len : init_seq_len + num_tokens_to_check]
|
||||
eagle_generated = input_ids[i, init_seq_len : init_seq_len + num_tokens_to_check]
|
||||
|
||||
print(f" Batch {i}:")
|
||||
print(f" Target: {target_generated.tolist()}")
|
||||
print(f" Eagle: {eagle_generated.tolist()}")
|
||||
|
||||
assert torch.equal(target_generated, eagle_generated), (
|
||||
f"Batch {i}: First {num_tokens_to_check} generated tokens do not match!\n"
|
||||
f" Target: {target_generated.tolist()}\n"
|
||||
f" Eagle: {eagle_generated.tolist()}"
|
||||
)
|
||||
print(f"✓ First {num_tokens_to_check} generated tokens match for all batches!")
|
||||
|
||||
@ -443,3 +443,5 @@ l0_h100:
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
- examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights
- examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[1]
- examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[2]

@ -21,11 +21,11 @@ import pytest
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
from transformers import AutoConfig

from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
    Eagle3DrafterForCausalLM,
    Eagle3Model,
    Eagle3DraftOutput,
    EagleConfig,
)
from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry
@ -33,7 +33,6 @@ from tests.test_common.llm_data import hf_id_to_local_model_dir

EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


###############################################################################
# Mock classes for standalone Eagle testing
#
@ -43,6 +42,22 @@ EAGLE_MODEL_HUB_ID = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
###############################################################################


class MockEagleConfig(EagleConfig):
    """Config for standalone Eagle testing with embedding/lm_head loaded from checkpoint.

    In production, Eagle shares embedding/lm_head with the target model.
    For standalone testing, we need to load these from the checkpoint.
    """

    _drafter_defaults = {
        "llama": {
            "load_embedding_from_target": False,
            "load_lm_head_from_target": False,
            "num_capture_layers": 1,
        },
    }


class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
    """Test wrapper that provides random hidden states for standalone Eagle testing.

@ -55,7 +70,17 @@ class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
        self._hidden_size = config.hidden_size
        self._dtype = config.dtype

    def forward(self, input_ids, **kwargs):
    def forward(self, input_ids, position_ids, input_embeds=None, **kwargs):
        assert self.model.embed_tokens is not None, (
            "embed_tokens must be set before running standalone Eagle model."
        )
        assert self.lm_head is not None, (
            "lm_head must be set before running standalone Eagle model."
        )

        if input_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject mock hidden states if not provided
        if "hidden_states" not in kwargs:
            batch_size, seq_len = input_ids.shape
@ -64,19 +89,39 @@ class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
                dtype=self._dtype,
                device=input_ids.device,
            )
        return super().forward(input_ids, **kwargs)
        draft_output = super().forward(inputs_embeds, position_ids, **kwargs)
        logits = self.lm_head(draft_output.norm_hidden_state)
        return Eagle3DraftOutput(logits=logits, last_hidden_state=draft_output.last_hidden_state)


class MockEagleDrafterFactory(EagleDrafterFactory):
    """Test factory that uses MockEagle3ModelForCausalLM for standalone Eagle testing.

    This factory overrides the drafter mapping to use the mock model class which
    generates random hidden states, enabling testing without a target model.
    This factory directly builds MockEagle3ModelForCausalLM with MockEagleConfig,
    which loads embedding/lm_head from checkpoint for standalone testing.
    """

    _drafter_model_mapping = {
        "llama": MockEagle3ModelForCausalLM,
    }
    def _build_model(self, device):
        from contextlib import nullcontext

        from accelerate import init_empty_weights

        model_config, unused_kwargs = self._get_model_config()
        model_config = MockEagleConfig(model_config, model_config.model_type)

        with (init_empty_weights if device == "meta" else nullcontext)():
            model = MockEagle3ModelForCausalLM._from_config(model_config, **unused_kwargs)

        if device == "meta":
            if hasattr(model, "post_init"):
                model.post_init()
        else:
            model.to(device)

        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
        model.eval()

        return model


@pytest.fixture
@ -125,7 +170,7 @@ def test_eagle_model_torch_export():
    torch.export for potential TensorRT compilation.

    Note: We skip loading weights since torch.export only traces the computation
    graph (model architecture), not the actual weight values. Random init is fine.
    graph (model architecture).
    """
    print("\n" + "=" * 80)
    print("Test: EagleModel torch.export")
@ -141,40 +186,35 @@ def test_eagle_model_torch_export():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16

    # Use pretrained config (llama) - can provide it directly to instantiate Eagle3Model.
    config_path = eagle_path / "config.json"
    config = AutoConfig.from_pretrained(config_path)

    # Create model with random weights (no need to load for export test)
    model = Eagle3Model(config)
    model.to(device)
    model.eval()
    # Create model via EagleDrafterFactory (creates Eagle3DrafterForCausalLM)
    factory = EagleDrafterFactory(model=str(eagle_path), skip_loading_weights=True)
    model = factory.build_model(device)
    config = model.config

    # Create inputs for export
    batch_size = 1
    seq_len = 8
    hidden_dim = config.hidden_size

    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long
    )
    inputs_embeds = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype)
    position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    mock_hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=device, dtype=dtype)

    print("Export input shapes:")
    print(f" input_ids: {input_ids.shape}")
    print(f" inputs_embeds: {inputs_embeds.shape}")
    print(f" position_ids: {position_ids.shape}")
    print(f" hidden_states: {mock_hidden_states.shape}")

    example_args = (
        input_ids,
        inputs_embeds,
        position_ids,
        mock_hidden_states,
    )

    # Attempt torch.export
    try:
        exported_program = torch.export.export(model, args=example_args)
        exported_program = torch.export.export(
            model, args=example_args, kwargs={"hidden_states": mock_hidden_states}
        )
        print("✅ torch.export successful!")
        print("Graph module code preview (first 20 lines):")
        code_lines = exported_program.graph_module.code.split("\n")[:20]

@ -21,6 +21,9 @@ from test_common.llm_data import with_mocked_hf_download
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig


@pytest.mark.skip(
    reason="OOM on A30 GPUs on CI - speculative model loading does not support model_kwargs reduction"
)
@pytest.mark.parametrize("use_hf_speculative_model", [False])
@with_mocked_hf_download
def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):
