[#9241][feat] AutoDeploy: Support Eagle3 Speculative Decoding (#9869)

Supports the two-model flow without the overlap scheduler or the chain drafter. The draft model runs on the PyTorch backend.
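
A minimal usage sketch distilled from the test added in this PR (model paths are placeholders; the standard llmapi generate() entry point is assumed):

    from tensorrt_llm import SamplingParams
    from tensorrt_llm._torch.auto_deploy.llm import LLM
    from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

    # Eagle3 two-model speculative decoding; the draft model runs on the PyTorch backend.
    speculative_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="<path/to/EAGLE3-LLaMA3.1-Instruct-8B>",  # placeholder path
        eagle3_one_model=False,          # two-model flow
        eagle3_layers_to_capture=None,   # None -> default capture layers are chosen by the transform
    )

    llm = LLM(
        model="<path/to/Llama-3.1-8B-Instruct>",  # placeholder path
        runtime="trtllm",
        world_size=1,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
        speculative_config=speculative_config,
        disable_overlap_scheduler=True,  # the overlap scheduler is not supported with this flow
    )

    outputs = llm.generate(
        ["What is the capital of France?"],
        SamplingParams(max_tokens=64, temperature=0),
    )
    print(outputs[0].outputs[0].text)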

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
gramnarayan 2025-12-24 20:30:42 -08:00 committed by GitHub
parent 1f8ed71d5f
commit a9eb5afc9f
9 changed files with 671 additions and 56 deletions

View File

@@ -75,6 +75,8 @@ transforms:
stage: pattern_matcher
quantize_mxfp4_moe:
stage: pattern_matcher
detect_hidden_states_for_capture:
stage: pattern_matcher
detect_sharding:
stage: sharding
simple_shard_only: false
@@ -163,6 +165,9 @@ transforms:
insert_cached_delta_rule:
stage: cache_init
backend: fla_delta
insert_cached_residual_add:
stage: cache_init
backend: cached_residual_add
initialize_cache:
stage: cache_init
run_per_gm: false

View File

@@ -8,7 +8,14 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from tensorrt_llm.models.modeling_utils import QuantConfig
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
from ...llmapi.llm_args import (
BaseLlmArgs,
BuildConfig,
EagleDecodingConfig,
KvCacheConfig,
SamplerType,
_ParallelConfig,
)
from .models import ModelFactory, ModelFactoryRegistry
from .utils._config import DynamicYamlMixInForSettings
from .utils.logger import ad_logger
@@ -150,6 +157,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")
draft_checkpoint_loader: Optional[object] = Field(
default=None,
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
)
### INFERENCE OPTIMIZER CONFIG #################################################################
mode: Literal["graph", "transformers"] = Field(
default="graph",
@@ -190,11 +202,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
),
)
draft_checkpoint_loader: Optional[object] = Field(
default=None,
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
)
### SEQUENCE INTERFACE CONFIG ##################################################################
max_input_len: int = Field(default=1024, description="The maximum input length.")
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
@@ -420,6 +427,19 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
msg = "AutoDeploy only supports parallelization via the `world_size` argument."
return _check_for_default_value_only(cls, value, info, msg)
@model_validator(mode="after")
def setup_hidden_state_capture(self):
if self.speculative_config is None or not isinstance(
self.speculative_config, EagleDecodingConfig
):
return self
self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
self.speculative_config.eagle3_layers_to_capture
)
return self
@model_validator(mode="after")
def validate_parallel_config(self):
"""Setup parallel config according to world_size.

View File

@@ -13,14 +13,17 @@ import copy
import types
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
from types import MethodType, SimpleNamespace
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from strenum import StrEnum
from torch._prims_common import DeviceLikeType
from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights
from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM
from tensorrt_llm._torch.pyexecutor._util import (
_create_kv_cache_manager,
get_decoding_mode,
@@ -32,9 +35,11 @@ from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, get_draft_tok
from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._torch.speculative import get_spec_drafter
from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.llmapi.llm_args import (
ContextChunkingPolicy,
EagleDecodingConfig,
LoadFormat,
SamplerType,
TorchLlmArgs,
@@ -57,6 +62,7 @@ from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler
from ...pyexecutor.scheduler import (
BindCapacityScheduler,
BindMicroBatchScheduler,
RequestList,
ScheduledRequests,
SimpleScheduler,
)
@@ -113,6 +119,90 @@ class _CacheManagerWithFakePool(KVCacheManager):
return self.num_blocks, 0
class ADHiddenStateManager(Eagle3ResourceManager):
def __init__(
self,
cache_seq_interface: CachedSequenceInterface,
config: EagleDecodingConfig,
max_num_requests: int,
max_seq_len: int,
max_num_tokens: int,
):
hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
dtype = hidden_state_buffer.dtype
hidden_size = hidden_state_buffer.shape[1]
super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)
self.hidden_state_write_indices: torch.Tensor = torch.empty(
max_num_tokens, dtype=torch.long, device="cuda"
)
def _get_hidden_state_buffers(
self, cache_seq_interface: CachedSequenceInterface
) -> List[torch.Tensor]:
hidden_state_buffers = []
for name, tensor in cache_seq_interface.named_args.items():
if "hidden_states_cache" in name:
hidden_state_buffers.append(tensor)
if not hidden_state_buffers:
raise ValueError(
"No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
)
return hidden_state_buffers
def prepare_hidden_states_capture(
self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
) -> None:
"""Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
seq_lens = cache_seq_interface.info.seq_len
num_tokens = sum(seq_lens)
start_idx = 0
hidden_states_write_indices = []
for request, seq_len in zip(ordered_requests, seq_lens):
request_id = request.request_id
slot_id = self.slot_manager.get_slot(request_id)
self.start_indices[slot_id] = start_idx
hidden_states_write_indices.extend(range(start_idx, start_idx + seq_len))
start_idx += max(seq_len, self.max_total_draft_tokens + 1)
assert start_idx < self.hidden_states.shape[0], (
f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
)
if len(hidden_states_write_indices) != num_tokens:
raise ValueError(
f"len(hidden_state_write_indices) ({len(hidden_states_write_indices)}) != num_tokens \
({num_tokens}). Check whether ordered_requests matches up with seq_lens."
)
hidden_state_write_indices_host = torch.tensor(
hidden_states_write_indices, dtype=torch.long
)
self.hidden_state_write_indices[:num_tokens].copy_(
hidden_state_write_indices_host, non_blocking=True
)
def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
"""Capture configured hidden states that have been written by the model,
in a format that can be used by the draft model.
"""
full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)
if not full_hidden_states:
return
num_tokens = sum(cache_seq_interface.info.seq_len)
hidden_states = [hidden_state[:num_tokens] for hidden_state in full_hidden_states]
hidden_states = torch.cat(hidden_states, dim=1)
hidden_states = hidden_states.to(dtype=self.dtype)
token_idx = self.hidden_state_write_indices[:num_tokens]
self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
def construct_draft_llm_args(
ad_config: LlmArgs,
) -> TorchLlmArgs:
@@ -461,6 +551,10 @@ class ADEngine(ModelEngine):
kv_cache_manager = resource_manager.get_resource_manager(
ResourceManagerType.KV_CACHE_MANAGER
)
# resource manager for hidden state capture
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER
)
# requests in order of context, generate
context_requests = scheduled_requests.context_requests
@@ -471,6 +565,7 @@
r for r in scheduled_requests.generation_requests if get_draft_token_length(r) == 0
]
gen_requests = extend_requests + generation_requests
ordered_requests = context_requests + gen_requests
# info to be extracted
input_ids: List[List[int]] = []
position_ids: List[List[int]] = []
@@ -670,6 +765,13 @@
self.cache_seq_interface.info.run_host_prepare_for_attention_forward()
if spec_resource_manager is not None and isinstance(
spec_resource_manager, ADHiddenStateManager
):
spec_resource_manager.prepare_hidden_states_capture(
ordered_requests, self.cache_seq_interface
)
self.iter_states["num_ctx_requests"] = num_ctx_requests
self.iter_states["num_ctx_tokens"] = num_ctx_tokens
# TODO: handle extend requests and draft requests for specdec
@@ -710,14 +812,74 @@
outputs = {
"logits": self._compute_logits(),
}
# save hidden states after running model.forward() in _compute_logits()
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER
)
if spec_resource_manager is not None and isinstance(
spec_resource_manager, ADHiddenStateManager
):
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
if self.mapping is not None:
self._execute_logit_post_processors(scheduled_requests, outputs)
return outputs
def share_target_weights_with_draft(
target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
):
"""
Certain speculative decoding methods (e.g. Eagle3) require sharing the target model's embedding and lm_head weights
with the draft model. This function performs that sharing when required.
"""
assert isinstance(draft_model_engine.model, Eagle3ForCausalLM), (
f"Expected draft_model_engine.model to be Eagle3ForCausalLM, got {type(draft_model_engine.model)}"
)
def share_embedding_weights_with_draft(
target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
):
embedding_weight = get_input_embeddings(target_model_engine.model)
world_size = mpi_world_size()
assert world_size <= 1, f"This code assumes tp<=1. World size: {world_size}"
# Note: This simple forward function implementation assumes tp=1.
# TODO(govind): Handle the tp>1 case.
def new_embedding_forward(self, input_ids):
return F.embedding(input_ids, self.weight)
if draft_model_engine.model.model.embed_tokens is None:
submodule = torch.nn.Module()
submodule.forward = MethodType(new_embedding_forward, submodule)
submodule.weight = embedding_weight
draft_model_engine.model.model.embed_tokens = submodule
def share_lm_head_weights_with_draft(
target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
):
vocab_size = target_model_engine.cache_seq_interface.info.vocab_size_padded
lm_head_weight = get_lm_head_weights(target_model_engine.model)
assert lm_head_weight.shape[0] == vocab_size, (
f"Expected lm_head weight first dimension to be vocab_size={vocab_size}, "
f"but got shape {lm_head_weight.shape}"
)
if draft_model_engine.model.load_lm_head_from_target:
draft_model_engine.model.lm_head.weight = lm_head_weight
share_embedding_weights_with_draft(target_model_engine, draft_model_engine)
share_lm_head_weights_with_draft(target_model_engine, draft_model_engine)
def create_draft_model_engine_maybe(
ad_config: LlmArgs, engine, dist_mapping: Mapping, mpi_dist: MPIDist
ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
) -> Optional[PyTorchModelEngine]:
"""Create a draft model engine for speculative decoding.
@@ -745,7 +907,7 @@ def create_draft_model_engine_maybe(
chunked_prefill=ad_config.enable_chunked_prefill,
cache_reuse=kv_cache_config.enable_block_reuse,
has_speculative_draft_tokens=has_spec_drafter,
chunk_size=engine.llm_args.max_num_tokens,
chunk_size=target_engine.llm_args.max_num_tokens,
)
# Construct TorchLlmArgs for the draft model
@@ -753,6 +915,10 @@
ad_config=ad_config,
)
# The chain drafter is not currently supported for AutoDeploy.
# TODO(govind): Add it when we want to optimize two-model spec-dec performance.
drafting_loop_wrapper = None
draft_model_engine = PyTorchModelEngine(
model_path=draft_spec_config.speculative_model_dir,
llm_args=draft_llm_args,
@ -761,9 +927,14 @@ def create_draft_model_engine_maybe(
dist=mpi_dist,
spec_config=draft_spec_config,
is_draft_model=True,
drafting_loop_wrapper=None,
drafting_loop_wrapper=drafting_loop_wrapper,
)
if draft_spec_config.spec_dec_mode.is_eagle3():
share_target_weights_with_draft(
target_model_engine=target_engine, draft_model_engine=draft_model_engine
)
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
return draft_model_engine
@@ -855,9 +1026,11 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
spec_config = ad_config.speculative_config
if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
if spec_config is not None and not (
spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()
):
raise ValueError(
"Currently, AutoDeploy only supports speculative decoding in draft target mode."
"Currently, AutoDeploy only supports speculative decoding in draft target or eagle3 mode."
)
if spec_config is not None and ad_config.guided_decoding_backend is not None:
@@ -865,11 +1038,20 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
"Guided decoding is not currently supported for speculative decoding in AutoDeploy."
)
# Speculative resource manager not needed for DraftTargetDecoding.
spec_resource_manager = None
draft_model_engine = create_draft_model_engine_maybe(
ad_config=ad_config, engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
)
spec_resource_manager = (
ADHiddenStateManager(
cache_seq_interface=engine.cache_seq_interface,
config=spec_config,
max_num_requests=ad_config.max_batch_size,
max_seq_len=engine.llm_args.max_seq_len,
max_num_tokens=engine.llm_args.max_num_tokens,
)
if isinstance(spec_config, EagleDecodingConfig)
else None
)
# check kvcache config for partial block reuse

View File

@@ -25,7 +25,9 @@ from typing import Tuple
import torch
from torch.fx import GraphModule
from ...utils.node_utils import is_linear_op, is_op
from tensorrt_llm._torch.auto_deploy.utils._graph import get_lm_head_node
from ...utils.node_utils import is_linear_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@@ -54,9 +56,7 @@ class GatherLogitsBeforeLmHeadTransform(BaseTransform):
self._log_info("Applying GatherLogitsBeforeLmHead transform...")
# assume lm head node is the input to the output node
lm_head_node = gm.graph.find_nodes(op="output")[0].all_input_nodes[0]
if is_op(lm_head_node, torch.ops.aten.to):
lm_head_node = lm_head_node.all_input_nodes[0]
lm_head_node = get_lm_head_node(gm)
if is_linear_op(lm_head_node):
node_to_gather = lm_head_node.all_input_nodes[0]

View File

@@ -0,0 +1,239 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The transform passes to capture the hidden states of the target model."""
from typing import Dict, List, Optional, Set, Tuple, Type
import torch
from torch._ops import OpOverloadPacket
from torch.fx import GraphModule, Node
from ...custom_ops.attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
CacheConfig,
CacheInitializerDict,
MHACallable,
SequenceInfo,
)
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import get_all_layer_subgraphs, is_op
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
from .kvcache import InsertCachedAttention
@torch.library.custom_op("auto_deploy::residual_add_for_capture", mutates_args=())
def residual_add_for_capture(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
return torch.ops.aten.add(t1, t2)
@residual_add_for_capture.register_fake
def residual_add_for_capture_fake(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
return torch.ops.aten.add(t1, t2)
@torch.library.custom_op("auto_deploy::cached_residual_add", mutates_args=())
def cached_residual_add(
t1: torch.Tensor, t2: torch.Tensor, hidden_states_cache: torch.Tensor
) -> torch.Tensor:
ret = torch.ops.aten.add(t1, t2)
b, s, _ = ret.shape
num_tokens = b * s
hidden_states_cache[:num_tokens].copy_(ret.view(num_tokens, -1), non_blocking=True)
return ret
@cached_residual_add.register_fake
def cached_residual_add_fake(
t1: torch.Tensor, t2: torch.Tensor, hidden_states_cache: torch.Tensor
) -> torch.Tensor:
return torch.ops.aten.add(t1, t2)
class DetectHiddenStatesForCaptureConfig(TransformConfig):
"""Configuration for the hidden states detection transform."""
# Whether to capture hidden states at all. If False we will not capture any layers.
capture_hidden_states: bool = False
# TODO: figure out how to get layers to capture.
# We should consider if we can use the layer indices stored in eagle checkpoints, e.g.
# https://huggingface.co/nvidia/gpt-oss-120b-Eagle3/blob/main/config.json#L9-L14
eagle3_layers_to_capture: Optional[Set[int]] = None
def set_default_eagle3_layers_to_capture(self, num_hidden_layers: int):
"""
Set the default layers to capture when hidden-state capture is requested but
no layers to capture are explicitly provided.
"""
if num_hidden_layers <= 6:
raise ValueError("Not enough hidden layers for default EAGLE3 capture")
self.eagle3_layers_to_capture = {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}
@TransformRegistry.register("detect_hidden_states_for_capture")
class DetectHiddenStatesForCapture(BaseTransform):
"""Detect the hidden states we should capture in the graph."""
config: DetectHiddenStatesForCaptureConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return DetectHiddenStatesForCaptureConfig
def collect_residual_add_nodes(self, gm: GraphModule) -> Dict[int, Node]:
def _get_layer_number(lin_node: Node) -> Optional[int]:
weight = lin_node.args[1]
if weight.op == "get_attr":
subnames = weight.target.split(".")
for subname in subnames:
if subname.isdigit():
return int(subname)
return None
# find last closing linear node of each layer
# from there we will find the residual add node for that layer
layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)
residual_add_nodes: Dict[int, Node] = {}
for layer_subgraph in layer_subgraphs:
lin_node_closing = layer_subgraph.terminating_node
# need layer number to correctly identify the residual add node
layer_number = _get_layer_number(lin_node_closing)
if layer_number is None:
continue
# Conditions for identifying the hidden states produced after the residual add:
# take the first node after the closing linear node that either
# 1. is an add node with > 1 users (hidden states of every layer but the last are consumed directly
#    by the next layer as well as by its residual add). Stopping here prevents us from picking up
#    a later layer's residual add node.
# 2. is the last add node in a single-user chain (for the last layer or layers with no subsequent
#    residual add). This stops us before we walk into the next layer.
res_node = lin_node_closing
while len(res_node.users) == 1:
user_node = list(res_node.users)[0]
if not is_op(user_node, torch.ops.aten.add):
break
res_node = user_node
if is_op(res_node, torch.ops.aten.add):
# this stores the last residual add node encountered for each layer
residual_add_nodes[layer_number] = res_node
return residual_add_nodes
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if not self.config.capture_hidden_states:
info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
return gm, info
if gm.graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.residual_add_for_capture.default
):
info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
return gm, info
residual_add_nodes = self.collect_residual_add_nodes(gm)
if self.config.eagle3_layers_to_capture is None:
num_hidden_layers = len(residual_add_nodes)
self.config.set_default_eagle3_layers_to_capture(num_hidden_layers)
residual_add_nodes = {
k: v for k, v in residual_add_nodes.items() if k in self.config.eagle3_layers_to_capture
}
assert residual_add_nodes.keys() == self.config.eagle3_layers_to_capture, (
f"Unable to find residual add nodes for layers. Expected: {self.config.eagle3_layers_to_capture}, \
Found: {residual_add_nodes.keys()}"
)
# replace residual add nodes with special placeholder nodes
for layer_number, res_node in residual_add_nodes.items():
with gm.graph.inserting_before(res_node):
new_node = gm.graph.call_function(
torch.ops.auto_deploy.residual_add_for_capture.default,
args=res_node.args,
kwargs=res_node.kwargs,
)
res_node.replace_all_uses_with(new_node)
gm.graph.erase_node(res_node)
cnt = len(residual_add_nodes)
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=(cnt == 0), has_valid_shapes=(cnt == 0)
)
return gm, info
@AttentionRegistry.register("cached_residual_add")
class CachedResidualAdd(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
return "bsnd"
@classmethod
def get_num_qkv_args(cls) -> int:
return 2
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.residual_add_for_capture
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.cached_residual_add
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
hidden_size = source_attn_node.meta["val"].shape[-1]
hidden_type = source_attn_node.meta["val"].dtype
def _get_hidden_states_cache(si: SequenceInfo):
return torch.empty(si.max_num_tokens, hidden_size, device=si.device, dtype=hidden_type)
return {"hidden_states_cache": _get_hidden_states_cache}
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return []
@TransformRegistry.register("insert_cached_residual_add")
class InsertCachedResidualAdd(InsertCachedAttention):
"""A transform to handle residual add cache operations."""

View File

@@ -17,7 +17,7 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from .logger import ad_logger
from .node_utils import is_op
from .node_utils import get_weight_tensor, is_op
_NoValType = type("_NoValType", (), {})
_NO_VAL = _NoValType()
@@ -344,3 +344,60 @@ def placeholders_on_meta(mod: nn.Module) -> bool:
return True
return False
def get_input_embeddings(model: nn.Module) -> torch.Tensor:
"""Find the unique embedding node across all graph modules."""
embedding_weights = []
for _, gm in named_graphmodules(model):
found_nodes = gm.graph.find_nodes(
op="call_function", target=torch.ops.aten.embedding.default
)
for node in found_nodes:
embedding_weights.append(get_weight_tensor(gm, node))
if hasattr(model, "get_input_embeddings"):
embedding_weights.append(model.get_input_embeddings())
for _, gm in named_graphmodules(model):
if hasattr(gm, "get_input_embeddings"):
embedding_weights.append(gm.get_input_embeddings())
assert len(embedding_weights) > 0, "No embedding weights found"
unique_embedding_weights = [embedding_weights[0]]
for weight in embedding_weights:
if weight is not unique_embedding_weights[0]:
unique_embedding_weights.append(weight)
assert len(unique_embedding_weights) == 1, (
f"Expected exactly 1 unique embedding weight, but found {len(unique_embedding_weights)}."
)
return unique_embedding_weights[0]
def get_output_node(model: nn.Module) -> tuple[GraphModule, Node]:
"""Find the unique output node across all graph modules."""
output_nodes = []
for _, gm in named_graphmodules(model):
output_nodes.extend([(gm, node) for node in gm.graph.find_nodes(op="output")])
assert len(output_nodes) == 1, f"Expected exactly 1 output node, but found {len(output_nodes)}."
return output_nodes[0]
def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Node:
if output_node is None:
output_node = gm.graph.find_nodes(op="output")[0]
lm_head_node = output_node.all_input_nodes[0]
if is_op(lm_head_node, torch.ops.aten.to):
lm_head_node = lm_head_node.all_input_nodes[0]
return lm_head_node
def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
gm, output_node = get_output_node(model)
lm_head_node = get_lm_head_node(gm, output_node)
return get_weight_tensor(gm, lm_head_node)

View File

@@ -923,6 +923,12 @@ def shape(node: Node) -> Tuple[int, ...]:
return node.meta["val"].shape
def get_weight_tensor(gm: GraphModule, node: Node) -> "torch.Tensor":
"""Extract the weight tensor from a node within a GraphModule."""
weight_name = extract_param_names_from_node(node)[0]
return gm.get_parameter(weight_name)
def draw_graph(gm: GraphModule, filename: str):
"""
Dump graphmodule to SVG file using PyTorch's built-in drawer.

View File

@@ -19,39 +19,58 @@ import pytest
from build_and_run_ad import ExperimentConfig, main
from defs.conftest import llm_models_root
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig, KvCacheConfig
prompts = [
"What is the capital of France?",
"Please explain the concept of gravity in simple words and a single sentence.",
"What is the capital of Norway?",
"What is the highest mountain in the world?",
]
EAGLE_MODEL_SUBPATH = "EAGLE3-LLaMA3.1-Instruct-8B"
LLAMA_BASE_SUBPATH = "llama-3.1-model/Llama-3.1-8B-Instruct"
DRAFT_TARGET_MAX_DRAFT_LEN = 3
EAGLE_MAX_DRAFT_LEN = 3
def get_model_paths():
"""Get model paths using llm_models_root()."""
models_root = llm_models_root()
base_model = os.path.join(
models_root,
"llama-3.1-model/Llama-3.1-8B-Instruct",
)
speculative_model = os.path.join(
base_model = os.path.join(models_root, LLAMA_BASE_SUBPATH)
draft_target_model = os.path.join(
models_root,
"llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
)
eagle_model = os.path.join(models_root, EAGLE_MODEL_SUBPATH)
print(f"Base model path: {base_model}")
print(f"Speculative model path: {speculative_model}")
return base_model, speculative_model
print(f"DraftTarget draft model path: {draft_target_model}")
print(f"EAGLE model path: {eagle_model}")
return base_model, draft_target_model, eagle_model
def run_with_autodeploy(model, speculative_model_dir, batch_size):
def make_draft_target_config(spec_model_path: str):
return DraftTargetDecodingConfig(
max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model_dir=spec_model_path
)
def make_eagle3_config(spec_model_path: str):
return EagleDecodingConfig(
max_draft_len=EAGLE_MAX_DRAFT_LEN,
speculative_model_dir=spec_model_path,
eagle3_one_model=False,
eagle3_layers_to_capture=None,
)
def run_with_autodeploy(model, speculative_config, batch_size):
"""Run AutoDeploy with or without speculative decoding.
Args:
model: Path to the base model
speculative_model_dir: Path to the speculative model (None for baseline mode)
speculative_config: Speculative decoding config (None for baseline mode)
batch_size: Number of prompts to process
Returns:
@@ -60,16 +79,11 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
# Select prompts based on batch size
selected_prompts = prompts[:batch_size]
# Configure speculative decoding if speculative_model_dir is provided
spec_config = None
if speculative_model_dir is not None:
spec_config = DraftTargetDecodingConfig(
max_draft_len=3, speculative_model_dir=speculative_model_dir
)
spec_config = speculative_config
# Configure KV cache
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=0.1,
free_gpu_memory_fraction=0.01,
)
# Configure AutoDeploy LLM arguments
@@ -81,9 +95,6 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
"world_size": 1,
"kv_cache_config": kv_cache_config,
"disable_overlap_scheduler": True,
"transforms": {
"fuse_rmsnorm": {"rmsnorm_backend": "triton"},
},
"max_num_tokens": 64,
}
@@ -116,34 +127,46 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
return result["prompts_and_outputs"]
@pytest.mark.parametrize("batch_size", [1, 4])
def test_autodeploy_spec_dec(batch_size):
"""Test AutoDeploy speculative decoding with different batch sizes.
# Note: This test checks exact equality of outputs between speculative and baseline modes.
# This can fail for larger batch sizes due to nondeterminism with in-flight batching.
# TODO: Figure out a robust test for output correctness that can pass for larger batch sizes.
@pytest.mark.parametrize("spec_dec_mode", ["draft_target", "eagle3"])
def test_autodeploy_spec_dec_output(spec_dec_mode):
"""Test AutoDeploy speculative decoding output correctness.
Runs with and without speculative decoding and verifies outputs are identical.
"""
print("\n" + "=" * 80)
print(f"Testing AutoDeploy Speculative Decoding - Batch Size {batch_size}")
print(f"Testing AutoDeploy Speculative Decoding ({spec_dec_mode}) - Output Correctness")
print("=" * 80)
base_model, speculative_model = get_model_paths()
base_model, draft_target_model, eagle_model = get_model_paths()
# Select model and config based on mode
if spec_dec_mode == "draft_target":
spec_model = draft_target_model
spec_config = make_draft_target_config(spec_model)
elif spec_dec_mode == "eagle3": # eagle3
spec_model = eagle_model
spec_config = make_eagle3_config(spec_model)
else:
raise ValueError(f"Unsupported speculative decoding mode: {spec_dec_mode}")
print(f"\nBase Model: {base_model}")
print(f"Speculative Model: {speculative_model}")
print(f"Batch Size: {batch_size}")
print(f"Speculative Model ({spec_dec_mode}): {spec_model}")
# Run with speculative decoding
print("\n[1/2] Running with speculative decoding enabled...")
spec_outputs = run_with_autodeploy(
model=base_model, speculative_model_dir=speculative_model, batch_size=batch_size
model=base_model,
speculative_config=spec_config,
batch_size=1,
)
print(f"Generated {len(spec_outputs)} outputs with speculative decoding")
# Run without speculative decoding (baseline)
print("\n[2/2] Running without speculative decoding (baseline)...")
baseline_outputs = run_with_autodeploy(
model=base_model, speculative_model_dir=None, batch_size=batch_size
)
baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1)
print(f"Generated {len(baseline_outputs)} outputs in baseline mode")
# Verify outputs are identical
@@ -171,3 +194,85 @@ def test_autodeploy_spec_dec(batch_size):
print("\n" + "=" * 80)
print("SUCCESS! All outputs are identical between spec-dec and baseline modes")
print("=" * 80)
def test_autodeploy_eagle3_acceptance_rate():
"""Test Eagle3 acceptance rate with AutoDeploy engine.
Runs Eagle3 speculative decoding with streaming and verifies
that the acceptance rate is above a minimum threshold.
"""
print("\n" + "=" * 80)
print("Testing AutoDeploy Eagle3 Acceptance Rate")
print("=" * 80)
base_model, _, eagle_model = get_model_paths()
print(f"\nBase Model: {base_model}")
print(f"Eagle3 Model: {eagle_model}")
max_draft_len = EAGLE_MAX_DRAFT_LEN
# Configure Eagle3 speculative decoding
speculative_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model,
eagle3_one_model=False,
eagle3_layers_to_capture=None,
)
# Configure KV cache
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=0.01,
)
# Create AutoDeploy LLM with Eagle3 speculative decoding
# We directly instantiate the LLM class instead of using the main() function
# so that we can stream the outputs to see acceptance rates without needing to
# collect them in the executor.
llm = LLM(
model=base_model,
skip_loading_weights=False,
runtime="trtllm",
world_size=1,
kv_cache_config=kv_cache_config,
speculative_config=speculative_config,
disable_overlap_scheduler=True,
max_num_tokens=64,
)
# Tokenize 2 prompts to test multiple sequential requests
batch_tok_ids = [llm.tokenizer.encode(p) for p in prompts[:2]]
sampling_params = SamplingParams(max_tokens=128, temperature=0, seed=42)
print("\nRunning Eagle3 speculative decoding with streaming...")
# Process each request sequentially and verify acceptance rate
for i in range(len(batch_tok_ids)):
num_tokens = 0
num_drafted = 0
num_accepted = 0
for output in llm.generate_async(batch_tok_ids[i], sampling_params, streaming=True):
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)
accept_rate = num_accepted / num_drafted
print(f"\nRequest {i + 1} Acceptance Rate Statistics:")
print(f" Total tokens drafted: {num_drafted}")
print(f" Total tokens accepted: {num_accepted}")
print(f" Acceptance rate: {accept_rate:.2%}")
# Verify acceptance rate is above minimum threshold (10%)
min_acceptance_rate = 0.10
assert accept_rate > min_acceptance_rate, (
f"Request {i + 1}: Acceptance rate {accept_rate:.2%} is below minimum threshold {min_acceptance_rate:.0%}"
)
print("\n" + "=" * 80)
print("SUCCESS! All requests passed acceptance rate threshold")
print("=" * 80)

View File

@@ -116,8 +116,9 @@ l0_h100:
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec[1]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec[4]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
- condition:
ranges:
system_gpu_count: