[None][feat] Export ONNX for DriveOS LLM (#10117)

Signed-off-by: yocox <yocox@nvidia.com>
nvyocox 2026-01-31 04:43:11 +08:00 committed by GitHub
parent f42a6cbae0
commit 4af47208d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 4115 additions and 21 deletions

.gitignore (vendored), 1 change
View File

@ -1,5 +1,6 @@
__pycache__/
.vscode
.cursor
*.engine
*.engine.config
*.cache

View File

@ -68,6 +68,7 @@ init_ubuntu() {
gdb \
git-lfs \
clang \
graphviz \
lld \
llvm \
libclang-rt-dev \

View File

@ -0,0 +1,93 @@
# Export ONNX for EdgeLLM
AutoDeploy provides a mode to export PyTorch/HuggingFace models to ONNX format specifically designed for EdgeLLM deployment. This mode performs graph transformations to fuse RoPE (Rotary Position Embedding) and attention operations into a single `AttentionPlugin` operation, then exports the optimized graph to ONNX.
## Overview
The `export_edgellm_onnx` mode differs from the standard AutoDeploy workflow in two key ways:
1. **Operation Fusion**: Fuses `torch_rope_with_explicit_cos_sin` and `torch_cached_attention_with_cache` into a single `AttentionPlugin` operation
2. **ONNX Export**: Outputs an ONNX model file instead of a TensorRT engine
## Quick Start
Use the `onnx_export_llm.py` script to export a model:
```bash
cd examples/auto_deploy
python onnx_export_llm.py --model "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
```
This will export the model to ONNX format in the current directory.
## Command Line Options
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--model` | str | Required | HuggingFace model name or path to a local checkpoint |
| `--device` | str | `cpu` | Device to use for export (`cpu` or `cuda`) |
| `--output_dir` | str | `.` | Directory to save the exported ONNX model |
## Examples
### Basic Export
Export a DeepSeek model with default settings:
```bash
python onnx_export_llm.py --model "Qwen/Qwen2.5-0.5B-Instruct"
```
### Custom Output Location
Export to a specific output directory:
```bash
python onnx_export_llm.py \
--model "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" \
--output_dir "./exported_models"
```
## Output Files
The export process generates the following files in the output directory:
| File | Description |
|------|-------------|
| `model.onnx` | The exported ONNX model with fused attention operations |
| `config.json` | Model configuration (architecture, hidden size, etc.) |
| `tokenizer.json` | Tokenizer vocabulary and configuration |
| `tokenizer_config.json` | Tokenizer settings |
| `special_tokens_map.json` | Special token mappings |
| `processed_chat_template.json` | Processed chat template for inference |
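As a quick sanity check, the generated artifacts can be inspected with standard tooling. Below is a minimal sketch, assuming the default output directory and the file names from the table above:

```python
import json
from pathlib import Path

out_dir = Path(".")  # adjust if --output_dir was used

# Model configuration written alongside the ONNX file
config = json.loads((out_dir / "config.json").read_text())
print(config.get("num_hidden_layers"), config.get("hidden_size"))

# Processed chat template used at inference time
chat_template = json.loads((out_dir / "processed_chat_template.json").read_text())
print(list(chat_template.get("roles", {}).keys()))
```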
## Programmatic Usage
You can also use the ONNX export functionality programmatically:
```python
from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig
# Create AutoDeploy config with export_edgellm_onnx mode
ad_config = AutoDeployConfig(
model="Qwen/Qwen2.5-0.5B-Instruct",
mode="export_edgellm_onnx",
max_batch_size=8,
max_seq_len=512,
device="cpu",
)
# Configure attention backend
ad_config.attn_backend = "torch"
# Optionally customize output location
ad_config.transforms["export_to_onnx"]["output_dir"] = "./my_output"
# Run the export
LLM(**ad_config.to_llm_kwargs())
```
## Notes
- **Device Selection**: Using `cpu` for the `--device` option is recommended to reduce GPU memory footprint during export.
- **Custom Operations**: The exported ONNX model contains custom operations (e.g., `AttentionPlugin`) in the `trt` domain that require corresponding implementations in the target inference runtime.
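
To confirm which custom operations the target runtime must provide, the exported graph can be scanned for nodes outside the standard ONNX domain. A minimal sketch, assuming the `onnx` Python package and the default `model.onnx` file name:

```python
import onnx

model = onnx.load("model.onnx")

# Collect (domain, op_type) pairs for non-standard operations such as AttentionPlugin
custom_ops = sorted(
    {(n.domain, n.op_type) for n in model.graph.node if n.domain not in ("", "ai.onnx")}
)
for domain, op_type in custom_ops:
    print(f"{domain}::{op_type}")
```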

View File

@ -60,6 +60,7 @@ The exported graph then undergoes a series of automated transformations, includi
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)
- [Export ONNX for EdgeLLM](./advanced/export_onnx.md)
## Roadmap

View File

@ -0,0 +1,78 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ONNX export script for AutoDeploy models.
This script exports a HuggingFace model to ONNX format using the AutoDeploy
transform pipeline directly, without initializing the full LLM executor.
"""
import argparse
from tensorrt_llm._torch.auto_deploy.export import export_onnx
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
def main():
parser = argparse.ArgumentParser(
description="Export HuggingFace model to ONNX format using AutoDeploy."
)
parser.add_argument(
"--model",
type=str,
required=True,
help="The HF model to use for onnx export.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="The device to use when exporting the model.",
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The directory to save the exported ONNX model.",
)
args = parser.parse_args()
print(f"Constructing model from {args.model}")
# To enable a dynamic batch_size, the export batch size must be > 1.
# NOTE(yoco): Originally this was 2, but for unknown reasons the batch_size then
# collapses to a static int 2 even though it is explicitly marked as a dynamic axis.
# Stranger still, with 13 the batch_size stays dynamic. Some value between 2 and 13
# may also work; we use 13 here for debugging purposes.
max_batch_size = 13
max_seq_len = 4
# Prepare the AutoDeploy config, mode is export_edgellm_onnx
ad_config = LlmArgs(
model=args.model,
mode="export_edgellm_onnx",
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
device=args.device,
)
ad_config.attn_backend = "torch"
if args.output_dir is not None:
ad_config.transforms["export_to_onnx"]["output_dir"] = args.output_dir
# Use direct InferenceOptimizer instead of LLM to avoid executor initialization
export_onnx(ad_config)
if __name__ == "__main__":
main()

View File

@ -13,7 +13,7 @@
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601281024-10117
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601281024-10117
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601281024-10117
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601281024-10117

View File

@ -10,6 +10,8 @@ mpi4py
numpy<2
onnx>=1.18.0,<1.20.0
onnx_graphsurgeon>=0.5.2
onnxscript==0.5.4
graphviz
openai
polygraphy
psutil

View File

@ -0,0 +1,126 @@
# This is the set of transforms running in "graph" mode. In this mode, we capture the full graph
# of the model and optimize it for inference.
transforms:
############################################################################################
# BUILD MODEL, EXPORT TO GRAPH MODULE, AND CLEAN UP
############################################################################################
build_model:
stage: factory
run_per_gm: false
device: meta
requires_clean_graph: false
export_to_gm:
stage: export
clone_state_dict: false
strict: false
run_per_gm: false
requires_clean_graph: false
cleanup_noop_slice:
stage: post_export
cleanup_noop_add:
stage: post_export
cleanup_input_constraints:
stage: post_export
############################################################################################
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
############################################################################################
match_moe_pattern:
stage: pattern_matcher
match_dense_moe_pattern:
stage: pattern_matcher
match_bmm_moe_pattern:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
run_shape_prop: true
match_eager_attention:
stage: pattern_matcher
requires_shape_prop: true
match_sdpa_to_torch_attention:
stage: pattern_matcher
match_grouped_attention:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher
attn_layout: bsnd
match_rope_pattern:
stage: pattern_matcher
match_rope_layout:
stage: pattern_matcher
expected_layout: bsnd
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
eliminate_redundant_transposes:
stage: pattern_matcher
quantize_int4_linear_from_config:
stage: pattern_matcher
quantize_fp8_linear_from_config:
stage: pattern_matcher
quantize_nvfp4_linear_from_config:
stage: pattern_matcher
quantize_fp8_bmm_from_config:
stage: pattern_matcher
quantize_fp8_from_graph:
stage: pattern_matcher
quantize_nvfp4_from_graph:
stage: pattern_matcher
quantize_fp8_moe:
stage: pattern_matcher
quantize_nvfp4_moe:
stage: pattern_matcher
quantize_mxfp4_moe:
stage: pattern_matcher
############################################################################################
# MOVE MODEL AND LOAD WEIGHTS
############################################################################################
load_weights:
stage: weight_load
run_per_gm: false
checkpoint_device: cpu
move_inputs_to_device:
stage: weight_load
checkpoint_device: cpu
run_per_gm: false
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
fuse_gemms:
stage: post_load_fusion
enabled:
false # https://github.com/NVIDIA/TensorRT-LLM/issues/4674: this causes OOMs on GPU.
# However, when exporting ONNX for EdgeLLM we compile the model on CPU,
# so this transform could be enabled here.
# TODO(yoco): It currently makes non-strict ONNX export fail, so it stays disabled.
fuse_moe:
stage: post_load_fusion
enabled: true
backend: trtllm
fuse_fp8_moe:
stage: post_load_fusion
enabled: true
backend: trtllm
fuse_nvfp4_moe:
stage: post_load_fusion
enabled: false
############################################################################################
# VISUALIZE GRAPH
############################################################################################
visualize_namespace:
stage: visualize
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
############################################################################################
# FUSE ROPE ATTENTION AND EXPORT TO ONNX
############################################################################################
fuse_rope_attention:
stage: export_onnx
short_reshape_attention_output:
stage: export_onnx
gather_last_token_ids:
stage: export_onnx
adapt_to_edgellm:
stage: export_onnx
export_to_onnx:
stage: export_onnx
output_dir: "."
output_name: "model.onnx"

View File

@ -0,0 +1,194 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom operations for ONNX export of attention mechanisms.
This module provides placeholder custom operations for exporting attention-related
operations to ONNX format. These operations serve as intermediate representations
during the graph transformation pipeline and are intended to be replaced by actual
backend implementations during deployment.
"""
from typing import Tuple
import torch
@torch.library.custom_op("auto_deploy::torch_onnx_attention_plugin", mutates_args=())
def attention_plugin(
# Inputs
qkv: torch.Tensor,
past_key_values: torch.Tensor,
context_lengths: torch.Tensor,
rope_rotary_cos_sin: torch.Tensor,
kvcache_start_index: torch.Tensor,
# Attributes
enable_tree_attention: int,
head_size: int,
num_kv_heads: int,
num_q_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Fused attention operation with integrated RoPE (Rotary Position Embedding).
This custom operation combines rotary position embedding and scaled
dot-product attention into a single fused operation. It also handles
KV-cache management for efficient autoregressive generation.
Note:
This is a placeholder implementation for ONNX export. The actual computation
is performed by the backend runtime (e.g., TensorRT, EdgeLLM).
Args:
qkv: Concatenated query, key, value tensor of shape
[batch_size, seq_len, (num_q_heads + 2 * num_kv_heads) * head_size].
past_key_values: KV-cache tensor of shape
[batch_size, 2, num_kv_heads, past_seq_len, head_size].
context_lengths: Sequence lengths for each batch element, shape [batch_size].
rope_rotary_cos_sin: Precomputed RoPE cosine and sine values of shape
[batch_size, max_seq_len, head_size].
kvcache_start_index: Starting index in KV-cache for each batch element,
shape [batch_size].
enable_tree_attention: Flag to enable tree attention mode (0 or 1).
head_size: Dimension of each attention head.
num_kv_heads: Number of key-value heads (for grouped-query attention).
num_q_heads: Number of query heads.
Returns:
A tuple containing:
- attention_output: Attention output tensor of shape
[batch_size, seq_len, num_q_heads, head_size].
- present_key_values: Updated KV-cache tensor of shape
[batch_size, 2, num_kv_heads, present_seq_len, head_size].
"""
return qkv.new_empty(0), past_key_values.new_empty(0)
@attention_plugin.register_fake
def attention_plugin_fake(
qkv: torch.Tensor,
past_key_values: torch.Tensor,
context_lengths: torch.Tensor,
rope_rotary_cos_sin: torch.Tensor,
kvcache_start_index: torch.Tensor,
enable_tree_attention: int,
head_size: int,
num_kv_heads: int,
num_q_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Fake implementation of attention_plugin for torch.compile shape inference.
This function computes the output shapes without performing actual computation,
enabling torch.compile to trace through the custom operation.
Args:
qkv: Concatenated QKV tensor.
past_key_values: Previous KV-cache tensor.
context_lengths: Sequence lengths per batch.
rope_rotary_cos_sin: RoPE embedding values.
kvcache_start_index: KV-cache start indices.
enable_tree_attention: Tree attention flag.
head_size: Attention head dimension.
num_kv_heads: Number of KV heads.
num_q_heads: Number of query heads.
Returns:
Tuple of empty tensors with correct shapes for attention output and
present KV-cache.
"""
batch_size = qkv.size(0)
seq_len = qkv.size(1)
past_len = past_key_values.size(3)
present_kv_len = seq_len + past_len
attn_shape = (batch_size, seq_len, num_q_heads, head_size)
present_kv_shape = (batch_size, 2, num_kv_heads, present_kv_len, head_size)
return torch.empty(attn_shape, device=qkv.device, dtype=qkv.dtype), torch.empty(
present_kv_shape, device=past_key_values.device, dtype=past_key_values.dtype
)
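# Illustrative shape example (hypothetical sizes, not taken from any specific model):
# with num_q_heads=14, num_kv_heads=2, head_size=64, batch_size=1, seq_len=8 and a
# past KV length of 24, qkv has shape [1, 8, (14 + 2 * 2) * 64] = [1, 8, 1152],
# the attention output has shape [1, 8, 14, 64], and the present KV-cache has
# shape [1, 2, 2, 8 + 24, 64] = [1, 2, 2, 32, 64].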
def _fake_gather_nd(data: torch.Tensor, indices: torch.Tensor, batch_dims: int) -> torch.Tensor:
"""Compute output shape for GatherND operation without actual gathering.
This helper function creates an empty tensor with the correct output shape
for the GatherND operation, used for both the actual op and its fake
implementation.
Args:
data: Source tensor of shape [batch_size, seq_len, embedding_dim].
indices: Index tensor of shape [batch_size, num_selected] or
[batch_size, num_selected, index_depth].
batch_dims: Number of leading batch dimensions (must be 1).
Returns:
Empty tensor with shape [batch_size, num_selected, embedding_dim].
Raises:
AssertionError: If batch_dims != 1, data is not 3D, or indices is not 2D/3D.
"""
assert batch_dims == 1, "Currently only batch_dims = 1 is supported"
assert data.ndim == 3, "Currently only 3D data tensors are supported"
assert indices.ndim == 2 or indices.ndim == 3, "Currently only 2D or 3D indices tensors are supported"
dim_batch_size = indices.size(0)
dim_selected_token = indices.size(1)
dim_emb = data.size(-1)
result_shape = [dim_batch_size, dim_selected_token, dim_emb]
return torch.empty(result_shape, device=data.device, dtype=data.dtype)
@torch.library.custom_op("auto_deploy::torch_onnx_gather_nd", mutates_args=())
def gather_nd(
data: torch.Tensor,
indices: torch.Tensor,
batch_dims: int,
) -> torch.Tensor:
"""N-dimensional gather operation following ONNX gather_nd semantics.
Gathers slices from the data tensor based on indices, supporting batched
operations. This operation is commonly used for selecting specific tokens
from a sequence based on their positions.
Note:
This is a placeholder implementation for ONNX export. The actual
computation is performed by the backend runtime.
Args:
data: Source tensor to gather from, shape [batch_size, seq_len, embedding_dim].
indices: Index tensor specifying which elements to gather,
shape [batch_size, num_selected] or [batch_size, num_selected, index_depth].
batch_dims: Number of leading dimensions to treat as batch dimensions.
Currently only batch_dims=1 is supported.
Returns:
Gathered tensor of shape [batch_size, num_selected, embedding_dim].
"""
return _fake_gather_nd(data, indices, batch_dims)
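# Illustrative shape example (hypothetical sizes): with data of shape [2, 16, 512],
# indices of shape [2, 3, 1] (one index per selected token) and batch_dims=1,
# the gathered output has shape [2, 3, 512].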
@gather_nd.register_fake
def gather_nd_fake(
data: torch.Tensor,
indices: torch.Tensor,
batch_dims: int,
) -> torch.Tensor:
"""Fake implementation of gather_nd for torch.compile shape inference.
Args:
data: Source tensor to gather from.
indices: Index tensor for gathering.
batch_dims: Number of batch dimensions.
Returns:
Empty tensor with the correct output shape.
"""
return _fake_gather_nd(data, indices, batch_dims)

View File

@ -3,7 +3,7 @@
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.export as te
@ -15,6 +15,9 @@ from ..utils.logger import ad_logger
from ..utils.node_utils import is_op
from .interface import apply_export_patches
if TYPE_CHECKING:
from ..llm_args import LlmArgs
try:
from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
except ImportError:
@ -342,3 +345,57 @@ def torch_export_to_gm(
ad_logger.debug("exported graph: " + str(egm))
return egm
def export_onnx(ad_config: "LlmArgs") -> nn.Module:
"""Export model to ONNX using InferenceOptimizer directly.
This is a lightweight export path that avoids initializing the full LLM executor,
which requires KVCacheManager and other runtime components not needed for ONNX export.
Args:
ad_config: The AutoDeploy configuration for the model. Should use a mode like
"export_edgellm_onnx" that includes the export_to_onnx transform.
Returns:
The transformed model after running through the inference optimizer pipeline.
Example:
>>> from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
>>> from tensorrt_llm._torch.auto_deploy.export import export_onnx
>>>
>>> ad_config = LlmArgs(
... model="meta-llama/Llama-2-7b-hf",
... mode="export_edgellm_onnx",
... max_batch_size=13,
... max_seq_len=4,
... device="cpu",
... )
>>> ad_config.transforms["export_to_onnx"]["output_dir"] = "/tmp/onnx_output"
>>> model = export_onnx(ad_config)
"""
# Import here to avoid circular imports
from ..shim.interface import CachedSequenceInterface
from ..transform.optimizer import InferenceOptimizer
# 1. Create factory from config
factory = ad_config.create_factory()
# 2. Create CachedSequenceInterface (lightweight, no KVCacheManager initialization)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
device=ad_config.device,
kv_cache_config=ad_config.kv_cache_config,
max_num_tokens=ad_config.max_num_tokens,
vocab_size_padded=factory.vocab_size_padded,
)
# 3. Create InferenceOptimizer with transform config
inference_optimizer = InferenceOptimizer(
factory=factory,
config=ad_config.transforms,
)
# 4. Run the transform pipeline (includes export_to_onnx transform)
return inference_optimizer(cache_seq_interface)

View File

@ -206,7 +206,7 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
)
### INFERENCE OPTIMIZER CONFIG #################################################################
mode: Literal["graph", "transformers"] = Field(
mode: Literal["graph", "transformers", "export_edgellm_onnx"] = Field(
default="graph",
description="The mode to use for the inference optimizer. Currently, we "
"support only the 'graph' and 'transformers' modes, i.e., full-graph capture + optimization"
@ -335,5 +335,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
mapping = {
"graph": str(config_path / "default.yaml"),
"transformers": str(config_path / "transformers.yaml"),
"export_edgellm_onnx": str(config_path / "export_edgellm_onnx.yaml"),
}
return mapping.get(mode)

View File

@ -0,0 +1,691 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch GraphModule Visualization Tool
This module provides functionality to convert PyTorch GraphModule to Graphviz diagrams.
Supports different node styles and detailed graph annotations.
Key Features:
- Convert FX GraphModule to Graphviz diagrams
- Display tensor shape information on edges
- Adjust edge width based on tensor element count
- Intelligent port assignment for multi-input/output handling
- Color coding based on tensor identity
Usage Example:
import torch
import torch.fx as fx
from graph_module_visualizer import to_dot
# Trace model
model = YourModel()
traced = fx.symbolic_trace(model)
# Generate visualization
dot = to_dot(traced, name="model", save_path="model_graph", format="svg", include_shapes=True)
Requirements: pip install graphviz
"""
import math
import re
from typing import Any, Dict, Optional
import graphviz
import torch
import torch.fx as fx
from torch.fx import GraphModule
from ..utils.logger import ad_logger
def _calculate_edge_width(val) -> float:
"""
Calculate edge width based on tensor element count
Formula: log10(num_elements) + 2
Args:
val: FakeTensor, Tensor, or any object with .shape attribute
"""
min_width = 2.0
max_width = 10.0
if not hasattr(val, "shape"):
return min_width # Default width
try:
# Calculate total number of elements
num_elements = 1
for dim in val.shape:
if isinstance(dim, int):
num_elements *= dim
else:
num_elements *= 1
if num_elements <= 0:
return min_width
# Calculate width: log10(element count) + 2
width = math.log10(num_elements) + 2
# Constrain width range (2.0 to 10.0)
width = max(min_width, min(max_width, width))
return width
except (ValueError, TypeError):
return min_width # Use default width on error
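# Example: a tensor of shape (8, 1024, 1024) has 8,388,608 elements, so its edge
# width is log10(8388608) + 2 ≈ 8.9 (clamped to the [2.0, 10.0] range).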
def _get_edge_color(source_node_name: str, output_index: int) -> str:
"""
Assign color based on source node and actual output index
Ensures all edges of the same tensor use the same color
"""
colors = [
"#FF6B9D80", # Pink
"#4ECDC480", # Mint green
"#45B7D180", # Sky blue
"#96CEB480", # Light green
"#DDA0DD80", # Light purple
"#98D8C880", # Teal
"#BB8FCE80", # Violet
"#85C1E980", # Light blue
"#F8C47180", # Peach
"#82E0AA80", # Mint
"#F7DC6F80", # Lemon yellow
"#AED6F180", # Light sky blue
]
# Use combination of source_node + output_index
# Ensures multiple edges of the same tensor have the same color
tensor_id = f"{source_node_name}_{output_index}"
color_index = hash(tensor_id) % len(colors)
return colors[color_index]
def _get_port_of_five(input_idx, total_inputs):
"""
0 xxxxxxxx a = 7
1 xxxxxxxx b = 2
2 xxxxxxx x = 2 * 8 = 16
3 xxxxxxx
4 xxxxxxx
"""
K = 5
a, b = total_inputs // K, total_inputs % K
x = b * (a + 1)
if input_idx < x:
return input_idx // (a + 1)
else:
return (input_idx - x) // a + b
def _get_input_port(input_index: int, total_inputs: int) -> str:
"""
Get input port based on input index and total input count
"""
if total_inputs <= 1:
return "n" # Single input uses default port
elif total_inputs == 2:
return "nw" if input_index == 0 else "ne" # Northwest, northeast
elif total_inputs == 3:
ports = ["nw", "n", "ne"] # Northwest, north, northeast
return ports[input_index]
elif total_inputs == 4:
ports = ["nw", "n", "ne", "e"]
return ports[input_index]
else:
# 5+ inputs: west, northwest, north, northeast, east in order
ports = ["w", "nw", "n", "ne", "e"]
# Cycle through ports for more than 5 inputs
return ports[_get_port_of_five(input_index, total_inputs)]
def _get_output_port(output_index: int, total_outputs: int) -> str:
"""
Get output port based on output index and total output count (symmetric to input, but on bottom)
"""
if total_outputs <= 1:
return "s" # Single output uses default port
elif total_outputs == 2:
return "sw" if output_index == 0 else "se" # Southwest, southeast
elif total_outputs == 3:
ports = ["sw", "s", "se"] # Southwest, south, southeast
return ports[output_index]
else:
# 4+ outputs: west, southwest, south, southeast, east in order
ports = ["w", "sw", "s", "se", "e"]
# Cycle through ports for more than 5 outputs
return ports[_get_port_of_five(output_index, total_outputs)]
def to_dot(
graph_module: GraphModule,
name: str,
save_path: str,
format: str = "svg",
include_shapes: bool = True,
) -> Optional["graphviz.Digraph"]:
"""
Convert PyTorch GraphModule to Graphviz diagram
Args:
graph_module: GraphModule to visualize
name: Name of the diagram
save_path: Save path, if None uses name
format: Output format ('png', 'pdf', 'svg', 'dot', etc.)
include_shapes: Whether to include tensor shape information
Returns:
graphviz.Digraph object
"""
# Create Graphviz diagram
dot = graphviz.Digraph(
name=name, comment=f"PyTorch GraphModule: {graph_module.__class__.__name__}", format=format
)
# Set graph attributes
dot.attr(rankdir="TB") # Top to bottom
dot.attr("node", shape="box", style="rounded,filled", height="0.2")
# Remove default edge color, let each edge use its own color
# Node style configuration
node_styles = {
"placeholder": {"fillcolor": "lightgreen", "shape": "box"},
"get_attr": {"fillcolor": "lightcyan", "shape": "box"},
"call_function": {"fillcolor": "lightblue", "shape": "box"},
"call_method": {"fillcolor": "lightyellow", "shape": "box"},
"call_module": {"fillcolor": "lightpink", "shape": "box"},
"output": {"fillcolor": "lightcoral", "shape": "box"},
}
# Analyze graph structure
graph = graph_module.graph
nodes = list(graph.nodes)
node_labels = {}
# Process each node
for node in nodes:
# Get basic node information
node_name = node.name
op_type = node.op
# Create node label
label = _get_node_label(graph_module, node)
node_labels[node_name] = label
# Set node style
style = node_styles.get(op_type, {"fillcolor": "white", "shape": "box"})
# Add node to diagram
node_attrs = {
"label": label,
"fillcolor": style["fillcolor"],
"shape": style["shape"],
"tooltip": node_name, # Use node name as tooltip
}
# If the node has no value, set the fillcolor to red
if "val" not in node.meta:
node_attrs["fillcolor"] = "red"
elif isinstance(node.meta["val"], torch.Tensor):
node_attrs["label"] += "\n" + str(node.meta["val"].device)
elif isinstance(node.meta["val"], (list, tuple)):
for val in node.meta["val"]:
if isinstance(val, torch.Tensor):
node_attrs["label"] += "\n" + str(val.device)
else:
node_attrs["label"] += "\n" + str(val)
dot.node(node_name, **node_attrs)
# First collect all edge information
edges = [] # Format: (source_node, target_node, val, source_output_index, target_input_index)
node_inputs = {} # Input list for each node: [(source_node_name, output_index)]
node_outputs = {} # Output list for each node: [(target_node_name, input_index)]
# Initialize
for node in nodes:
node_inputs[node.name] = []
node_outputs[node.name] = []
# Collect edge information
def _add_edge_from_node_with_index(input_idx, source_node: fx.Node, target_node: fx.Node):
"""Add edge from source_node to target_node, automatically determining correct output index"""
# Extract val (FakeTensor or Tensor) from node.meta["val"]
# This is more reliable than tensor_meta since FakeTensorProp always populates val
val = None
if include_shapes and hasattr(source_node, "meta") and "val" in source_node.meta:
val = source_node.meta["val"]
# Calculate indices
source_output_index = _determine_output_index(source_node, target_node)
# Add edge and update indices (store tuple containing node name and corresponding index)
edges.append((source_node.name, target_node.name, val, source_output_index, input_idx))
node_inputs[target_node.name].append((source_node.name, source_output_index))
node_outputs[source_node.name].append((target_node.name, input_idx))
def _determine_output_index(source_node: fx.Node, target_node: fx.Node) -> int:
"""Determine the output index for the edge from source_node to target_node"""
# Check if target_node is a getitem operation
if (
target_node.op == "call_function"
and hasattr(target_node.target, "__name__")
and target_node.target.__name__ == "getitem"
and len(target_node.args) >= 2
):
# target_node.args[0] is source node, target_node.args[1] is index
if target_node.args[0] == source_node and isinstance(target_node.args[1], int):
return target_node.args[1]
# By default, FX nodes have only one output
return 0
def _fix_problematic_negative_one(text):
"""Fix the problematic 9223372036854775807 number back to -1"""
return text.replace("9223372036854775807", "-1")
def _format_constant_label(value):
"""Format constant value for display as node label"""
# Special case: direct problematic integer
if isinstance(value, int) and value == 9223372036854775807: # 2^63 - 1
return "-1"
# Handle different value types
if isinstance(value, (int, float, bool, str)):
result = str(value)
elif hasattr(value, "shape"):
# Tensor-like object with shape
if hasattr(value, "numel") and value.numel() > 6:
return f"shape={tuple(value.shape)}"
else:
result = str(value)
elif isinstance(value, (list, tuple)):
# Collection of elements
if len(value) > 6:
return f"length={len(value)}"
else:
result = str(value)
else:
# Other types
result = str(value)[:50] + ("..." if len(str(value)) > 50 else "")
# Apply the fix to any string representation
return _fix_problematic_negative_one(result)
# Store constants to be created later
constants_to_create = []
def _process_arg(input_idx, arg, target_node: fx.Node):
"""Process a single argument and create appropriate edges in the graph.
Handles three types of arguments:
- fx.Node: Creates an edge from the source node to the target node.
- Container (list/tuple): Iterates through and creates edges for any fx.Node elements.
- Constant: Creates a constant node and adds it to the graph with appropriate edges.
Args:
input_idx: The index of this argument in the target node's input list.
arg: The argument to process. Can be an fx.Node, a container (list/tuple),
or a constant value.
target_node: The fx.Node that receives this argument as input.
Note:
This function modifies external state including `constants_to_create`,
`node_inputs`, `node_outputs`, and `edges`.
"""
if isinstance(arg, fx.Node):
_add_edge_from_node_with_index(input_idx, arg, target_node)
elif isinstance(arg, (list, tuple)):
for sub_arg in arg:
if isinstance(sub_arg, fx.Node):
_add_edge_from_node_with_index(input_idx, sub_arg, target_node)
else:
# This is a constant value - store for later creation
const_node_name = f"const_{target_node.name}_{input_idx}"
constants_to_create.append((const_node_name, arg, target_node.name, input_idx))
# Add to node tracking immediately
node_inputs[target_node.name].append(const_node_name)
if const_node_name not in node_outputs:
node_outputs[const_node_name] = []
node_outputs[const_node_name].append((target_node.name, input_idx))
# Create edge from constant to target
edges.append((const_node_name, target_node.name, None, 0, input_idx))
def _determine_node_output_count(node_name: str) -> int:
"""Determine the actual output count of a node (applies to all node types)"""
# Check if any getitem operations use this node
max_index = -1
for n in nodes:
# Only check getitem operation pattern, don't restrict source node op type
if (
n.op == "call_function"
and hasattr(n.target, "__name__")
and n.target.__name__ == "getitem"
and len(n.args) >= 2
and hasattr(n.args[0], "name")
and n.args[0].name == node_name
and isinstance(n.args[1], int)
):
max_index = max(max_index, n.args[1])
# If getitem operations found, output count is max_index + 1
if max_index >= 0:
return max_index + 1
# By default, nodes have only one output
return 1
# Traverse all nodes to collect edges
for node in nodes:
if not node.args:
continue
for input_idx, arg in enumerate(node.args):
_process_arg(input_idx, arg, node)
# Print 10 nodes with most inputs
# ad_logger.debug("Nodes with most inputs:")
# node_inputs_sorted = sorted(node_inputs.items(), key=lambda x: len(x[1]), reverse=True)
# for node_name, input_list in node_inputs_sorted[:10]:
# ad_logger.debug(f" {node_name}: {len(input_list)}")
# Print 10 nodes with most outputs
node_outputs_sorted = sorted(node_outputs.items(), key=lambda x: len(x[1]), reverse=True)
# ad_logger.debug("Nodes with most outputs:")
large_fanout_nodes: Dict[str, int] = {}
for node_name, output_list in node_outputs_sorted[:10]:
if len(output_list) > 10:
large_fanout_nodes[node_name] = 0
# ad_logger.debug(f" {node_name}: {len(output_list)}")
# Overwrite large fanout nodes style
for node_name in large_fanout_nodes:
# Add node to diagram
dot.node(
node_name,
label=node_name,
fillcolor="#ffdddd80",
color="#88888880",
shape="box",
style="filled,dashed,rounded",
tooltip=node_name, # Use node name as tooltip
)
node_inputs_sorted = sorted(node_inputs.items(), key=lambda x: len(x[1]), reverse=True)
large_fanin_nodes: Dict[str, int] = {}
for node_name, input_list in node_inputs_sorted[:10]:
if len(input_list) > 12:
large_fanin_nodes[node_name] = 0
for node_name in large_fanin_nodes:
# Add node to diagram
dot.node(
node_name,
label=node_name,
fillcolor="#ddffdd80",
color="#88888880",
shape="box",
style="filled,dashed,rounded",
tooltip=node_name, # Use node name as tooltip
)
# Create constant nodes
for const_node_name, const_value, target_name, input_idx in constants_to_create:
const_label = _format_constant_label(const_value)
# Add constant node to dot graph
const_attrs = {
"label": const_label,
"shape": "box",
"style": "rounded,filled",
"fillcolor": "#ffffcc", # Light yellow background
"color": "#cccccc", # Light gray border
"width": "0.2",
"fontsize": "10",
}
dot.node(const_node_name, **const_attrs)
# Add edges with ports and colors
for source_name, target_name, val, source_output_index, target_input_index in edges:
edge_attrs = {}
# Calculate ports (for graphical display positioning)
input_list = node_inputs[target_name]
# Use actual output count, not usage count
source_output_count = _determine_node_output_count(source_name)
input_port = _get_input_port(target_input_index, len(input_list))
output_port = _get_output_port(source_output_index, source_output_count)
# Build node names with ports
source_name_port = f"{source_name}:{output_port}" if output_port else source_name
target_name_port = f"{target_name}:{input_port}" if input_port else target_name
# Set edge color (based on actual output_index)
edge_color = _get_edge_color(source_name, source_output_index)
edge_attrs["color"] = edge_color
# Add tensor shape and width information
# val can be FakeTensor, Tensor, or other types with .shape attribute
if val is not None and include_shapes and hasattr(val, "shape"):
shape_str = str(tuple(val.shape))
# Add dtype if available
dtype_str = ""
if hasattr(val, "dtype"):
dtype_str = str(val.dtype).replace("torch.", "")
edge_attrs["xlabel"] = f"{shape_str}\n{dtype_str}" if dtype_str else shape_str
edge_attrs["fontsize"] = "10"
edge_attrs["fontcolor"] = "blue"
# Calculate edge width based on element count
width = _calculate_edge_width(val)
edge_attrs["penwidth"] = str(width)
# Nodes with a large fan-out can stall the graphviz layout algorithm, so we
# duplicate such nodes so that each edge gets its own source node and the layout
# algorithm can still make progress. The same applies to large fan-in nodes.
if source_name in large_fanout_nodes:
node_attrs = {
"fillcolor": "#ddffdd80",
"color": "#88888880",
"shape": "box",
"style": "filled,dashed,rounded",
}
large_fanout_nodes[source_name] += 1
source_name_port = source_name + f"___{large_fanout_nodes[source_name]}"
dot.node(
name=source_name_port,
label=source_name,
**node_attrs,
)
if target_name in large_fanin_nodes:
node_attrs = {
"fillcolor": "#ddffdd80",
"color": "#88888880",
"shape": "box",
"style": "filled,dashed,rounded",
}
large_fanin_nodes[target_name] += 1
target_name_port = target_name + f"___{large_fanin_nodes[target_name]}"
dot.node(name=target_name_port, label=target_name, **node_attrs)
dot.edge(source_name_port, target_name_port, **edge_attrs)
# Save diagram
try:
dot.render(save_path, cleanup=True)
ad_logger.info(f"Diagram saved: {save_path}.{format}")
with open(save_path + ".txt", "w") as f:
f.write(str(graph_module.graph))
except Exception as e:
ad_logger.error(f"Failed to save diagram: {e}")
return dot
def analyze_graph_structure(graph_module: GraphModule) -> Dict[str, Any]:
"""
Analyze structural statistics of GraphModule
Args:
graph_module: GraphModule to analyze
Returns:
Dictionary containing structural statistics
"""
graph = graph_module.graph
nodes = list(graph.nodes)
# Count node types
op_counts = {}
for node in nodes:
op_type = node.op
op_counts[op_type] = op_counts.get(op_type, 0) + 1
# Analyze connections
total_connections = 0
for node in nodes:
if node.args:
for arg in node.args:
if isinstance(arg, fx.Node):
total_connections += 1
elif isinstance(arg, (list, tuple)):
for sub_arg in arg:
if isinstance(sub_arg, fx.Node):
total_connections += 1
# Calculate graph complexity
complexity_score = len(nodes) + total_connections
return {
"total_nodes": len(nodes),
"node_types": op_counts,
"total_connections": total_connections,
"complexity_score": complexity_score,
"graph_depth": _calculate_graph_depth(nodes),
}
def _get_node_label(graph_module: GraphModule, node: fx.Node) -> str:
"""Get node label"""
if node.op == "call_function":
func_name = _get_function_name(node.target)
tokens = func_name.split(".")
assert len(tokens) <= 2, f"Function name {func_name} has more than 2 tokens"
label = tokens[0] if tokens[0] != "to" else func_name
elif node.op == "call_method":
label = str(node.target)
elif node.op == "call_module":
label = _get_module_name(graph_module, node.target)
elif node.op == "get_attr":
attr_name = str(node.target).split(".")[-1] if "." in str(node.target) else str(node.target)
label = attr_name
elif node.op == "placeholder":
label = "ph: " + str(node.name)
elif node.op == "output":
label = "out: " + str(node.name)
else:
label = node.op
return label
def _get_function_name(func) -> str:
"""Get simplified function name"""
if hasattr(func, "__name__"):
return func.__name__
func_str = str(func)
# Handle torch functions
if "torch." in func_str:
match = re.search(r"torch\.(\w+)\.(\w+)", func_str)
if match:
return f"{match.group(1)}.{match.group(2)}"
match = re.search(r"torch\.(\w+)", func_str)
if match:
return match.group(1)
# Handle built-in functions
if "built-in" in func_str:
match = re.search(r"'(\w+)'", func_str)
if match:
return match.group(1)
return str(func).split(".")[-1] if "." in str(func) else str(func)
def _get_module_name(graph_module: GraphModule, target) -> str:
"""Extract module name, handle numeric indices in Sequential"""
try:
# Try to get actual module type name
actual_module = graph_module.get_submodule(str(target))
module_type = actual_module.__class__.__name__
# Extract the last part of module name
module_name = str(target).split(".")[-1] if "." in str(target) else str(target)
# If it's numeric index (like modules in Sequential), show type name
if module_name.isdigit():
return module_type
else:
return module_name
except Exception:
# If unable to get module, fall back to original logic
module_name = str(target).split(".")[-1] if "." in str(target) else str(target)
return module_name
def _calculate_graph_depth(nodes) -> int:
"""Calculate maximum depth of the graph"""
# Build dependency relationships
dependencies = {}
for node in nodes:
dependencies[node.name] = []
if node.args:
for arg in node.args:
if isinstance(arg, fx.Node):
dependencies[node.name].append(arg.name)
elif isinstance(arg, (list, tuple)):
for sub_arg in arg:
if isinstance(sub_arg, fx.Node):
dependencies[node.name].append(sub_arg.name)
# Calculate depth of each node
depths = {}
def calculate_depth(node_name):
if node_name in depths:
return depths[node_name]
if not dependencies[node_name]:
depths[node_name] = 0
return 0
max_dep_depth = max(calculate_depth(dep) for dep in dependencies[node_name])
depths[node_name] = max_dep_depth + 1
return depths[node_name]
for node in nodes:
calculate_depth(node.name)
return max(depths.values()) if depths else 0

View File

@ -3,14 +3,16 @@
This module defines the base classes and interfaces for all transforms.
"""
import os
import time
from abc import ABC
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import total_ordering, wraps
from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, final
import torch
import torch.nn as nn
from pydantic import BaseModel, Field
from torch.fx import GraphModule, Node
@ -27,6 +29,7 @@ from ..utils._graph import (
)
from ..utils.cuda_mem_tracker import get_mem_info
from ..utils.logger import ad_logger
from .graph_module_visualizer import to_dot
# ANSI color codes for log formatting (set to False to disable colors)
# NOTE: colors disabled by default to make logging in CI/CD pipelines easier to read
@ -102,6 +105,7 @@ class Stages(Enum):
POST_LOAD_FUSION = "post_load_fusion" # post-loading fusion and perf optimizations of the graph
CACHE_INIT = "cache_init" # initialization of cached attention + (KV) cache initialization
VISUALIZE = "visualize" # visualization of the graph
EXPORT_ONNX = "export_onnx" # export the graph to onnx
COMPILE = "compile" # graph compilation stage using low-level compilers like torch.compile
def __lt__(self, other):
@ -166,6 +170,11 @@ class TransformConfig(BaseModel):
default=False,
description="Whether this transform requires shape propagation before it is applied.",
)
debug_visualize_dir: Optional[str] = Field(
default=None,
description="Debug visualization directory. None to disable visualization, "
"or a path string to specify the output directory.",
)
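# Example (hypothetical): visualization can be enabled for a single transform from a
# transforms YAML entry, e.g.
#   match_rope_pattern:
#     stage: pattern_matcher
#     debug_visualize_dir: "./debug_graphs"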
expect_mem_change: bool = Field(
default=False,
@ -257,6 +266,7 @@ def with_transform_logging(call_fn: Callable) -> Callable:
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
idx: int,
) -> nn.Module:
prefix = f"[stage={self.config.stage.value}, transform={self.get_transform_key()}]"
original_log = ad_logger.log
@ -268,7 +278,7 @@ def with_transform_logging(call_fn: Callable) -> Callable:
ad_logger.log = _patched_log # type: ignore[assignment]
try:
return call_fn(self, gm, cm, factory, shared_config)
return call_fn(self, gm, cm, factory, shared_config, idx)
finally:
ad_logger.log = original_log # type: ignore[assignment]
@ -346,6 +356,7 @@ class BaseTransform(ABC):
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
idx: int,
) -> nn.Module:
"""Apply the transform to the graph.
@ -354,6 +365,7 @@ class BaseTransform(ABC):
cm: The cached sequence interface defining the sequence interface.
factory: The model factory used to build the model.
shared_config: Global info shared between multiple transforms.
idx: The index of the transform in the pipeline.
Returns:
nn.Module: The transformed model.
@ -473,10 +485,33 @@ class BaseTransform(ABC):
autodeploy_meta[self._mem_history_key] = mem_history
self._set_autodeploy_meta(mod, autodeploy_meta)
self._visualize_graph(mod, idx)
# return the graph module
return mod
@final
def _visualize_graph(self, mod: nn.Module, idx: int) -> None:
"""Visualize the graph if debug visualization is enabled.
Args:
mod: The graph module to visualize.
idx: The index of the transform in the pipeline.
Note:
we may want to consider doing this for each subgraph.
See https://github.com/NVIDIA/TensorRT-LLM/issues/10203
"""
if not isinstance(mod, torch.fx.GraphModule):
return
visualize_dir = self.config.debug_visualize_dir
if not visualize_dir:
return
if not os.path.exists(visualize_dir):
os.makedirs(visualize_dir)
name_stem = f"gm_{idx + 1:02d}_{self.get_transform_key()}"
visualize_path = os.path.join(visualize_dir, f"{name_stem}")
to_dot(mod, name=name_stem, save_path=visualize_path, format="svg")
ad_logger.debug(f"[{idx + 1:02d}] Visualized {name_stem} to {visualize_path}")
@final
def _apply_per_gm_or_whole_model(
self,

View File

@ -0,0 +1,382 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For EdgeLLM API. Processes the chat template to create a JSON file with
chat template data for the following:
Roles:
- System
- User
- Assistant
Messages:
- Role
- Content
- Type
- text
- image
- video
The JSON file is saved to the exported ONNX model directory.
This implementation uses the HF tokenizer's apply_chat_template method with test cases
to extract the actual prefix/suffix patterns used by the model, rather than trying
to parse the Jinja template directly.
"""
import json
import os
import re
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import AutoConfig, AutoProcessor, AutoTokenizer
from ...utils.logger import ad_logger
def is_vlm(model_dir: str) -> bool:
"""Check if the model is a VLM."""
cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
cfg_dict = cfg.to_dict()
has_vision = "vision_config" in cfg_dict
has_phi4_vision = "image_embd_layer" in cfg_dict.get("embd_layer", {})
if has_vision or has_phi4_vision:
ad_logger.debug("Set use_prompt_tuning to True")
return True
else:
ad_logger.debug("Set use_prompt_tuning to False")
return False
@dataclass
class Message:
role: str
content: Union[str, List[Dict[str, str]]] = field(default_factory=list)
@dataclass
class SystemMessage(Message):
role: str = "system"
content: str = "<placeholder_system_prompt>"
@dataclass
class UserMessage(Message):
role: str = "user"
content: str = "<placeholder_user_text>"
@dataclass
class MultimodalUserMessage(Message):
role: str = "user"
content: List[Dict[str, str]] = field(
default_factory=lambda: [{"type": "text", "text": "<placeholder_user_text>"}]
)
def add_text_content(self, text: str):
self.content.append({"type": "text", "text": text})
def add_image_content(self, image: str):
self.content.append({"type": "image", "image": image})
def add_video_content(self, video: str):
self.content.append({"type": "video", "video": video})
@dataclass
class AssistantMessage(Message):
role: str = "assistant"
content: str = "<placeholder_assistant_text>"
# TODO: Add tool calling
# TODO: Add ToolMessage
def _format_messages(
tokenizer: Any, messages: List[Message], add_generation_prompt: bool = False
) -> str:
"""
Format the messages using the tokenizer's chat template.
Args:
tokenizer: HuggingFace loaded tokenizer
messages: List of messages
add_generation_prompt: Whether to add generation prompt
Returns:
Formatted text
Raises:
ValueError: If unable to format messages
"""
try:
# Convert dataclass messages to dictionaries using asdict
message_dicts = [asdict(msg) for msg in messages]
return tokenizer.apply_chat_template(
message_dicts, tokenize=False, add_generation_prompt=add_generation_prompt
)
except Exception:
# Try fallback: convert list content to string for tokenizers that don't support multimodal
try:
message_dicts = []
for msg in messages:
content = msg.content
# If content is a list, extract the first text element
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
content = item.get("text", "")
break
message_dicts.append({"role": msg.role, "content": content})
return tokenizer.apply_chat_template(
message_dicts, tokenize=False, add_generation_prompt=add_generation_prompt
)
except Exception as e2:
raise ValueError(
f"Unable to format messages using HuggingFace tokenizer's apply_chat_template method."
f"Messages need to be in the format: role: <str>, content: <str|list of dicts>. "
f"Check INPUT_FORMAT.md for more details."
f"Error: {e2}"
) from e2
def _extract_prefix_suffix(text: str, placeholder: str) -> Tuple[str, str]:
"""
Extract prefix and suffix from the differential text by finding the placeholder content.
Args:
text: The text to extract the prefix and suffix from
placeholder : The placeholder content to search for in the formatted output
Returns:
Tuple of (prefix, suffix) strings
"""
content_start = text.find(placeholder)
if content_start == -1:
return "", ""
prefix = text[:content_start]
suffix = text[content_start + len(placeholder) :]
return prefix, suffix
def _extract_content_pattern(
tokenizer: Any,
system_prompt: SystemMessage,
content_type: str,
placeholder: str,
text_only_formatted: str,
placeholder_text: str,
) -> Optional[str]:
"""
Extract the pattern for a specific content type (image/video) by comparing
with text-only message.
Args:
tokenizer: The loaded tokenizer
system_prompt: System message to use
content_type: Type of content ('image' or 'video')
placeholder: Placeholder string for the content
text_only_formatted: Formatted text-only message
placeholder_text: The text placeholder used
Returns:
Extracted pattern string or None if failed or tokenizer does not support multimodal content
"""
# Create user message with the content type
user_with_content = MultimodalUserMessage()
if content_type == "image":
user_with_content.add_image_content(placeholder)
elif content_type == "video":
user_with_content.add_video_content(placeholder)
else:
return None
with_content_formatted = _format_messages(tokenizer, [system_prompt, user_with_content])
# Extract the differential - what was added for this content type
if placeholder_text in text_only_formatted and placeholder_text in with_content_formatted:
# Find position after the placeholder in both
text_pos = text_only_formatted.find(placeholder_text) + len(placeholder_text)
content_pos = with_content_formatted.find(placeholder_text) + len(placeholder_text)
# Get what comes after the placeholder in both
text_only_suffix = text_only_formatted[text_pos:]
with_content_suffix = with_content_formatted[content_pos:]
# The pattern is what was added (the difference)
if text_only_suffix and with_content_suffix.endswith(text_only_suffix):
pattern = with_content_suffix[: -len(text_only_suffix)]
else:
pattern = with_content_suffix
# Strip dynamic prefixes like "Image 1:" or "Video 1:"
pattern = re.sub(rf"^{content_type.capitalize()} \d+:\s*", "", pattern)
return pattern if pattern else None
return None
def process_chat_template(model_dir: str, output_dir: str) -> None:
"""
Process the chat template from model's tokenizer and create a JSON file
with parsed template information.
This function uses the tokenizer's apply_chat_template method with various
test cases to extract the actual prefix/suffix patterns.
Args:
model_dir: Path to the model directory containing tokenizer files
output_dir: Path to save the chat_template.json file
Returns:
None
"""
ad_logger.info(f"Processing chat template from {model_dir}")
tokenizer = None
loaders = (
[AutoProcessor, AutoTokenizer] if is_vlm(model_dir) else [AutoTokenizer, AutoProcessor]
)
for ldr in loaders:
try:
tokenizer = ldr.from_pretrained(model_dir, trust_remote_code=True)
if getattr(tokenizer, "chat_template", None):
ad_logger.debug(f"Successfully loaded chat template from {ldr.__name__}")
break
else:
ad_logger.debug(f"{ldr.__name__} loaded but no chat template found")
tokenizer = None
except Exception as e:
ad_logger.error(f"Failed to load {ldr.__name__}: {e}")
tokenizer = None
if tokenizer is None:
ad_logger.debug("Skipping chat template processing - no chat template available")
return
ad_logger.debug("Extracting patterns from chat template...")
# Extract system role patterns (base case)
system_prompt = SystemMessage()
system_formatted = _format_messages(tokenizer, [system_prompt])
system_prefix, system_suffix = _extract_prefix_suffix(system_formatted, system_prompt.content)
# Extract user role patterns (compare with system base)
user_prompt = UserMessage()
user_formatted = _format_messages(tokenizer, [system_prompt, user_prompt])
user_prefix, user_suffix = _extract_prefix_suffix(
user_formatted[len(system_formatted) :], user_prompt.content
)
# Extract assistant role patterns (compare with user case)
assistant_prompt = AssistantMessage()
assistant_formatted = _format_messages(
tokenizer, [system_prompt, user_prompt, assistant_prompt]
)
assistant_prefix, assistant_suffix = _extract_prefix_suffix(
assistant_formatted[len(user_formatted) :], assistant_prompt.content
)
# Extract generation prompt
generation_formatted = _format_messages(
tokenizer, [system_prompt, user_prompt], add_generation_prompt=True
)
generation_prompt = generation_formatted[len(user_formatted) :]
# Build content types
content_types = {}
# Only extract multimodal patterns if this is a VLM model
if is_vlm(model_dir):
ad_logger.debug("Detected VLM model, extracting multimodal content patterns...")
# Get base text-only formatted message for comparison
user_text_only = MultimodalUserMessage()
text_only_formatted = _format_messages(tokenizer, [system_prompt, user_text_only])
placeholder_text = user_text_only.content[0]["text"]
# Extract image pattern
image_pattern = _extract_content_pattern(
tokenizer,
system_prompt,
"image",
"<placeholder_image_path>",
text_only_formatted,
placeholder_text,
)
if image_pattern:
content_types["image"] = {"format": image_pattern}
# Extract video pattern
video_pattern = _extract_content_pattern(
tokenizer,
system_prompt,
"video",
"<placeholder_video_path>",
text_only_formatted,
placeholder_text,
)
if video_pattern:
content_types["video"] = {"format": video_pattern}
else:
ad_logger.debug("Text-only LLM detected, skipping multimodal content pattern extraction")
# Extract default system prompt by testing without system message
user_only_prompt = UserMessage()
user_only_formatted = _format_messages(tokenizer, [user_only_prompt])
# Extract default system prompt
default_system_prompt = ""
# Check if a default system prompt was added
# The system message should appear in user_only_formatted if there's a default
system_start = user_only_formatted.find(system_prefix)
if system_start != -1:
# Extract the system content between prefix and suffix
content_start = system_start + len(system_prefix)
content_end = user_only_formatted.find(system_suffix, content_start)
if content_end != -1:
default_system_prompt = user_only_formatted[content_start:content_end]
# Remove the placeholder if it appears
if default_system_prompt == system_prompt.content:
default_system_prompt = ""
# Build the final JSON structure
chat_template_data = {
"model_path": model_dir,
"roles": {
"system": {"prefix": system_prefix, "suffix": system_suffix},
"user": {"prefix": user_prefix, "suffix": user_suffix},
"assistant": {"prefix": assistant_prefix, "suffix": assistant_suffix},
},
"content_types": content_types,
"generation_prompt": generation_prompt,
"default_system_prompt": default_system_prompt,
}
# Save to output directory
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "processed_chat_template.json")
with open(output_path, "w") as f:
json.dump(chat_template_data, f, indent=2)
ad_logger.info(f"Chat template saved to {output_path}")

View File

@@ -0,0 +1,205 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from ...utils.logger import ad_logger
EDGELLM_VERSION = "0.5.0.0"
def _export_native_llm_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Export LLM configuration with required fields."""
required_fields = [
"vocab_size",
"max_position_embeddings",
"hidden_size",
"intermediate_size",
"num_hidden_layers",
"num_attention_heads",
"num_key_value_heads",
"rope_theta",
"rope_scaling",
]
llm_config = {}
for field in required_fields:
if field not in config_dict:
raise KeyError(f"Required field '{field}' not found in config")
llm_config[field] = config_dict[field]
# Handle LongRoPE (rope_scaling already validated in required_fields)
rope_scaling = config_dict["rope_scaling"]
if rope_scaling and rope_scaling.get("type", None) == "longrope":
if "original_max_position_embeddings" not in config_dict:
raise KeyError("Required field 'original_max_position_embeddings' not found in config")
llm_config["original_max_position_embeddings"] = config_dict[
"original_max_position_embeddings"
]
# Handle head_dim
if "head_dim" in config_dict:
llm_config["head_dim"] = config_dict["head_dim"]
else:
ad_logger.warning(
"Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads"
)
llm_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"]
if "partial_rotary_factor" in config_dict:
llm_config["partial_rotary_factor"] = config_dict["partial_rotary_factor"]
else:
llm_config["partial_rotary_factor"] = 1.0
llm_config["model_type"] = "llm"
return llm_config
def _export_eagle_base_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Export EAGLE base configuration with required fields."""
required_fields = [
"vocab_size",
"max_position_embeddings",
"hidden_size",
"intermediate_size",
"num_hidden_layers",
"num_attention_heads",
"num_key_value_heads",
"rope_theta",
"rope_scaling",
]
eagle_config = {}
for field in required_fields:
if field not in config_dict:
raise KeyError(f"Required field '{field}' not found in config")
eagle_config[field] = config_dict[field]
# Handle head_dim
if "head_dim" in config_dict:
eagle_config["head_dim"] = config_dict["head_dim"]
else:
ad_logger.warning(
"Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads"
)
eagle_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"]
if "partial_rotary_factor" in config_dict:
eagle_config["partial_rotary_factor"] = config_dict["partial_rotary_factor"]
else:
eagle_config["partial_rotary_factor"] = 1.0
eagle_config["model_type"] = "eagle3_base"
return eagle_config
def _export_eagle_draft_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Export EAGLE draft configuration with required fields."""
required_fields = [
"hidden_size",
"max_position_embeddings",
"intermediate_size",
"num_hidden_layers",
"num_attention_heads",
"num_key_value_heads",
"rope_theta",
"rope_scaling",
]
draft_config = {}
for field in required_fields:
if field not in config_dict:
raise KeyError(f"Required field '{field}' not found in config")
draft_config[field] = config_dict[field]
# Handle head_dim
if "head_dim" in config_dict:
draft_config["head_dim"] = config_dict["head_dim"]
else:
ad_logger.warning(
"Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads"
)
draft_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"]
# Handle draft_vocab_size based on EAGLE version
if "draft_vocab_size" not in config_dict:
raise KeyError("Required field 'draft_vocab_size' not found in config")
draft_config["draft_vocab_size"] = config_dict["draft_vocab_size"]
# Add base model configuration fields
# The target_hidden_size from the model config represents the base model's hidden dimension
if "target_hidden_size" in config_dict:
# Use target_hidden_size * 3 as the base model hidden dimension (as per llm_export.py logic)
draft_config["base_model_hidden_size"] = config_dict["target_hidden_size"] * 3
else:
# Fallback: assume base model hidden size is 3x draft model (Eagle3 default)
draft_config["base_model_hidden_size"] = config_dict["hidden_size"] * 3
ad_logger.warning(
f"Warning: target_hidden_size not found, using default 3x draft hidden size: "
f"{draft_config['base_model_hidden_size']}"
)
# Set model_type for draft
draft_config["model_type"] = "eagle3_draft"
return draft_config
def export_vision_config(config: Any) -> Dict[str, Any]:
"""Export vision configuration without modification."""
config_dict = config.to_dict()
has_vision = "vision_config" in config_dict
has_phi4_vision = "image_embd_layer" in config_dict.get("embd_layer", {})
if not (has_vision or has_phi4_vision):
raise KeyError(
"Required field 'vision_config' or 'image_embd_layer' in 'embd_layer' not found in config"
)
# Add EdgeLLM API version
config_dict["edgellm_version"] = EDGELLM_VERSION
# Return the original config_dict as-is without any modification
# Since MRoPE needs the LLM config, ViTRunner will use the LLM config.
return config_dict
def export_llm_config(config: Any, model_type: str) -> Dict[str, Any]:
"""Export configuration based on model type and EAGLE version."""
config_dict = config.to_dict()
# Extract model name from config class
config_class_name = config.__class__.__name__
model_name = config_class_name.lower().replace("config", "")
# For multimodal models, use text_config if available
if "text_config" in config_dict:
ad_logger.info("Detected multimodal model, using text_config")
config_dict = config_dict["text_config"]
if model_type == "llm":
output_config = _export_native_llm_config(config_dict)
elif model_type == "eagle3_base":
output_config = _export_eagle_base_config(config_dict)
elif model_type == "eagle_draft":
output_config = _export_eagle_draft_config(config_dict)
else:
raise ValueError(f"Unsupported model type: {model_type}")
# Add model name to output
output_config["model"] = model_name
# Add EdgeLLM API version
output_config["edgellm_version"] = EDGELLM_VERSION
return output_config
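# Illustrative usage sketch (not part of the original file), assuming a HuggingFace
# config that exposes all required fields (vocab_size, rope_theta, rope_scaling, ...):
#
#   from transformers import AutoConfig
#   cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#   llm_cfg = export_llm_config(cfg, model_type="llm")
#   # llm_cfg keeps only the required fields plus "head_dim", "partial_rotary_factor",
#   # "model_type", "model", and "edgellm_version".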

View File

@@ -0,0 +1,258 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from onnx import defs
from onnxscript.values import Opset
_TRT_DOMAIN_NAME = "trt"
_AUTO_DEPLOY_DOMAIN_NAME = "auto_deploy"
# public opset objects, used in ONNX translation functions
trt_opset = Opset(_TRT_DOMAIN_NAME, 1)
auto_deploy_opset = Opset(_AUTO_DEPLOY_DOMAIN_NAME, 1)
# ONNX Custom Op Registration for RoPE
_torch_rope_with_explicit_cos_sin_schema = defs.OpSchema(
name="rope_with_explicit_cos_sin",
domain=_AUTO_DEPLOY_DOMAIN_NAME,
since_version=1,
doc="Rope with explicit cos and sin caches.",
inputs=[
defs.OpSchema.FormalParameter(
name="q",
description="Q tensor",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="k",
description="K tensor",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="cos",
description="Cos cache",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="sin",
description="Sin cache",
type_str="T",
),
],
outputs=[
defs.OpSchema.FormalParameter(
name="output",
description="Output tensor",
type_str="T",
)
],
type_constraints=[
(
"T",
["tensor(float)", "tensor(float16)", "tensor(bfloat16)"],
"Input and output data type.",
),
],
attributes=[
defs.OpSchema.Attribute(
name="unsqueeze_dim",
type=defs.OpSchema.AttrType.INT,
description="Unsqueeze dimension. Must be 1 or 2.",
required=True,
),
],
)
# ONNX Custom Op Registration for AttentionPlugin
_attention_plugin_schema = defs.OpSchema(
name="AttentionPlugin",
domain=_TRT_DOMAIN_NAME,
since_version=1,
doc="Fused RoPE + Attention operation for efficient inference.",
inputs=[
defs.OpSchema.FormalParameter(
name="qkv",
description="Concatenated Q, K, V tensors in shape [batch, seq_len, qkv_hidden_size]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="past_key_values",
description="Concatenated past K and V cache in shape [batch, 2, num_kv_heads, past_len, head_size]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="context_lengths",
description="Context lengths for each sequence in shape [batch]",
type_str="T1",
),
defs.OpSchema.FormalParameter(
name="rope_rotary_cos_sin",
description="Concatenated cos and sin values for RoPE in shape [max_seq_len, head_dim]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="kvcache_start_index",
description="KV cache start index for each sequence in shape [batch]",
type_str="T1",
),
],
outputs=[
defs.OpSchema.FormalParameter(
name="output",
description="Attention output in shape [batch, seq_len, hidden_size]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="present_key_values",
description="Updated K and V cache",
type_str="T",
),
],
type_constraints=[
(
"T",
["tensor(float16)", "tensor(float)", "tensor(bfloat16)"],
"Input and output data type for floating point tensors.",
),
(
"T1",
["tensor(int32)", "tensor(int64)"],
"Input data type for integer tensors.",
),
],
attributes=[
defs.OpSchema.Attribute(
name="enable_tree_attention",
type=defs.OpSchema.AttrType.INT,
description="Whether to enable tree attention (0 or 1).",
required=True,
),
defs.OpSchema.Attribute(
name="head_size",
type=defs.OpSchema.AttrType.INT,
description="Size of each attention head.",
required=True,
),
defs.OpSchema.Attribute(
name="num_kv_heads",
type=defs.OpSchema.AttrType.INT,
description="Number of key-value heads.",
required=True,
),
defs.OpSchema.Attribute(
name="num_q_heads",
type=defs.OpSchema.AttrType.INT,
description="Number of query heads.",
required=True,
),
],
)
# ONNX Custom Op Registration for torch_attention
_torch_attention_schema = defs.OpSchema(
name="torch_attention",
domain=_AUTO_DEPLOY_DOMAIN_NAME,
since_version=1,
doc="SDPA attention (with optional GQA) that supports bnsd and bsnd memory layouts.",
inputs=[
defs.OpSchema.FormalParameter(
name="query",
description="Query tensor [batch, seq_len_q/num_heads, num_heads/seq_len_q, head_dim]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="key",
description="Key tensor [batch, seq_len_k/num_kv_heads, num_kv_heads/seq_len_k, head_dim]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="value",
description="Value tensor [batch, seq_len_k/num_kv_heads, num_kv_heads/seq_len_k, head_dim]",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="attn_mask",
description="Optional attention mask in [batch, num_heads, seq_len_q, seq_len_k] layout",
type_str="T",
),
defs.OpSchema.FormalParameter(
name="sinks",
description="Optional sinks tensor",
type_str="T",
),
],
outputs=[
defs.OpSchema.FormalParameter(
name="output",
description="Attention output in the same layout as inputs",
type_str="T",
)
],
type_constraints=[
(
"T",
["tensor(float16)", "tensor(float)", "tensor(bfloat16)"],
"Input and output data type for floating point tensors.",
),
],
attributes=[
defs.OpSchema.Attribute(
name="dropout_p",
type=defs.OpSchema.AttrType.FLOAT,
description="Dropout probability.",
required=False,
),
defs.OpSchema.Attribute(
name="is_causal",
type=defs.OpSchema.AttrType.INT,
description="Whether to apply causal masking (0 or 1).",
required=False,
),
defs.OpSchema.Attribute(
name="scale",
type=defs.OpSchema.AttrType.FLOAT,
description="Attention scale factor.",
required=False,
),
defs.OpSchema.Attribute(
name="sliding_window",
type=defs.OpSchema.AttrType.INT,
description="Sliding window size for attention.",
required=False,
),
defs.OpSchema.Attribute(
name="logit_cap",
type=defs.OpSchema.AttrType.FLOAT,
description="Logit capping value.",
required=False,
),
defs.OpSchema.Attribute(
name="layout",
type=defs.OpSchema.AttrType.STRING,
description="Memory layout: 'bnsd' or 'bsnd'.",
required=False,
),
],
)
def register_onnx_schemas():
"""Register ONNX custom ops."""
defs.register_schema(_torch_rope_with_explicit_cos_sin_schema)
defs.register_schema(_torch_attention_schema)
defs.register_schema(_attention_plugin_schema)
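# Illustrative usage sketch (not part of the original file): the schemas must be
# registered before torch.onnx.export(..., dynamo=True) translates the custom ops.
#
#   register_onnx_schemas()
#   # Translation functions can then emit calls such as
#   #   trt_opset.AttentionPlugin(...) / auto_deploy_opset.rope_with_explicit_cos_sin(...)
#   # which appear in the exported model under the "trt" / "auto_deploy" domains.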

View File

@@ -0,0 +1,166 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from typing import Tuple
import torch
from torch.fx import GraphModule
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@TransformRegistry.register("adapt_to_edgellm")
class AdaptToEdgeLLM(BaseTransform):
"""Transform that adapts the model graph for EdgeLLM deployment.
This transform performs several modifications to make the model compatible
with EdgeLLM runtime requirements:
1. Converts all model weights to float16 precision
2. Adds float32 cast after the final linear layer (for logits output)
3. Inserts float16 casts after attention output reshapes
4. Changes any bfloat16 casts to float16 (EdgeLLM may not support bfloat16)
These modifications ensure proper data type handling throughout the model
while maintaining numerical precision where needed (e.g., logits output).
"""
def _add_cast_after_last_linear(self, gm: GraphModule) -> torch.fx.Node:
"""Add a float32 cast operation after the final linear layer.
The final linear layer produces logits which are typically kept in
float32 for numerical stability during softmax and sampling operations.
Args:
gm: The GraphModule to modify.
Returns:
The newly created cast node.
"""
graph = gm.graph
linear_nodes = graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_linear_simple.default, sort=True
)
assert len(linear_nodes) > 0, "No linear nodes found"
last_linear_node = linear_nodes[-1]
with graph.inserting_after(last_linear_node):
cast_node = graph.call_function(
torch.ops.aten.to.dtype, args=(last_linear_node, torch.float32)
)
last_linear_node.replace_all_uses_with(cast_node)
# Restore cast_node's input to last_linear_node
cast_node.update_arg(0, last_linear_node)
return cast_node
def _insert_cast_after_attn_reshape(self, gm: GraphModule) -> int:
"""Insert float16 cast after reshape nodes following AttentionPlugin output.
The AttentionPlugin may output tensors that need explicit casting to float16
before being consumed by subsequent operations (e.g., linear projections).
This ensures consistent data types throughout the attention block.
Graph transformation:
Before: AttentionPlugin[0] -> Reshape -> MatMul
After: AttentionPlugin[0] -> Reshape -> Cast(to float16) -> MatMul
Args:
gm: The GraphModule to modify.
Returns:
Number of cast nodes inserted.
"""
graph = gm.graph
# Find all AttentionPlugin nodes
attention_plugin_nodes = graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
)
num_inserted = 0
for attn_node in attention_plugin_nodes:
# Find getitem[0] for this AttentionPlugin (first output)
for user in attn_node.users:
if is_op(user, operator.getitem) and user.args[1] == 0:
getitem_0_node = user
# Find reshape nodes that use this getitem[0]
for reshape_user in list(getitem_0_node.users):
if is_op(reshape_user, torch.ops.aten.reshape.default):
reshape_node = reshape_user
# Insert cast (to float16) after reshape
with graph.inserting_after(reshape_node):
cast_node = graph.call_function(
torch.ops.aten.to.dtype,
args=(reshape_node, torch.float16),
)
reshape_node.replace_all_uses_with(cast_node)
# Fix: restore cast_node's input to reshape_node
# (replace_all_uses_with also replaced it)
cast_node.update_arg(0, reshape_node)
num_inserted += 1
ad_logger.debug(f"Inserted cast (to float16) after {reshape_node.name}")
return num_inserted
def _change_cast_bfloat16_to_float16(self, gm: GraphModule) -> int:
"""Replace all bfloat16 cast operations with float16 casts.
EdgeLLM or certain hardware backends may not support bfloat16 natively.
This method converts all bfloat16 casts to float16 for compatibility.
Args:
gm: The GraphModule to modify.
Returns:
Number of cast operations changed.
"""
graph = gm.graph
cast_nodes = graph.find_nodes(op="call_function", target=torch.ops.aten.to.dtype)
num_changed = 0
for cast_node in cast_nodes:
if cast_node.args[1] == torch.bfloat16:
cast_node.update_arg(1, torch.float16)
num_changed += 1
return num_changed
def _to_float16(self, gm: GraphModule) -> None:
"""Convert all model parameters and buffers to float16 precision.
Args:
gm: The GraphModule to convert.
"""
gm.half()
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
self._to_float16(gm)
logits_cast = self._add_cast_after_last_linear(gm)
assert logits_cast is not None, "Failed to add cast after last linear"
num_attn_casts = self._insert_cast_after_attn_reshape(gm)
num_bfloat16_casts = self._change_cast_bfloat16_to_float16(gm)
ad_logger.info(f"Changed {num_bfloat16_casts} bfloat16 casts to float16")
ad_logger.info(f"Adapted EdgeLLM model (inserted {num_attn_casts} attention casts)")
return gm, TransformInfo(
skipped=False, num_matches=num_attn_casts, is_clean=False, has_valid_shapes=True
)
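# Toy sketch of the insert-cast-and-rewire idiom used above (illustrative only, not part
# of the original file); `replace_all_uses_with` also rewrites the cast's own input,
# hence the follow-up `update_arg(0, ...)`:
#
#   import torch
#   from torch.fx import symbolic_trace
#
#   gm = symbolic_trace(lambda x: torch.relu(x))
#   relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
#   with gm.graph.inserting_after(relu):
#       cast = gm.graph.call_function(torch.ops.aten.to.dtype, args=(relu, torch.float16))
#   relu.replace_all_uses_with(cast)
#   cast.update_arg(0, relu)  # restore the cast's input after the blanket replacement
#   gm.recompile()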

View File

@@ -0,0 +1,418 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from pathlib import Path
from typing import Optional, Tuple, Type
import torch
from onnxscript import ir, opset20
from pydantic import Field
from torch.export import Dim
from torch.fx import GraphModule
from ...models import hf
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
from . import _onnx_schemas
from ._chat_template import process_chat_template
from ._config_export import export_llm_config
class ExportToONNXConfig(TransformConfig):
"""Configuration for the export to ONNX transform."""
output_dir: Path = Field(
description="The directory to save the exported ONNX model.",
)
is_eagle_base: bool = Field(
description="Whether the model is an Eagle base model.",
default=False,
)
# ============================================================================
# Custom translation functions for ONNX export
# ============================================================================
def _translate_rope_op(
q: ir.Tensor, k: ir.Tensor, cos: ir.Tensor, sin: ir.Tensor, unsqueeze_dim: int
):
return _onnx_schemas.auto_deploy_opset.rope_with_explicit_cos_sin(
q, k, cos, sin, unsqueeze_dim=unsqueeze_dim
)
def _translate_simple_linear_op(input: ir.Tensor, weight: ir.Tensor, bias: Optional[ir.Tensor]):
weight = opset20.Transpose(weight, perm=[1, 0])
if bias is None:
return opset20.MatMul(input, weight)
return opset20.Add(opset20.MatMul(input, weight), bias)
def _translate_gather_nd_op(data: ir.Tensor, indices: ir.Tensor, batch_dims: int):
return opset20.GatherND(data, indices, batch_dims=batch_dims)
def _translate_rope_attention_op(
qkv: ir.Tensor,
past_key_values: ir.Tensor,
context_lengths: ir.Tensor,
rope_rotary_cos_sin: ir.Tensor,
kvcache_start_index: ir.Tensor,
enable_tree_attention: int,
head_size: int,
num_kv_heads: int,
num_q_heads: int,
):
"""
ONNX custom op translation function for AttentionPlugin.
This function creates a custom ONNX op node in the trt domain.
The actual implementation will need to be provided in the inference engine
(e.g., ONNX Runtime or TensorRT) that loads this ONNX model.
Note: This is a translation function for torch.onnx.export's custom_translation_table,
so it should NOT have @script() decorator.
"""
# Call the custom op from the trt domain
return _onnx_schemas.trt_opset.AttentionPlugin(
qkv,
past_key_values,
context_lengths,
rope_rotary_cos_sin,
kvcache_start_index,
enable_tree_attention=enable_tree_attention,
head_size=head_size,
num_kv_heads=num_kv_heads,
num_q_heads=num_q_heads,
)
def _translate_torch_attention_op(
query: ir.Tensor,
key: ir.Tensor,
value: ir.Tensor,
attn_mask: Optional[ir.Tensor] = None,
sinks: Optional[ir.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
layout: str = "bnsd",
):
"""
ONNX custom op translation function for torch_attention.
This function creates a custom ONNX op node in the auto_deploy domain.
The actual implementation will need to be provided in the inference engine
(e.g., ONNX Runtime or TensorRT) that loads this ONNX model.
Note: This is a translation function for torch.onnx.export's custom_translation_table,
so it should NOT have @script() decorator.
"""
# Call the custom op from the auto_deploy domain
return _onnx_schemas.auto_deploy_opset.torch_attention(
query,
key,
value,
attn_mask,
sinks,
dropout_p=dropout_p,
is_causal=1 if is_causal else 0,
scale=scale,
sliding_window=sliding_window,
logit_cap=logit_cap,
layout=layout,
)
@TransformRegistry.register("export_to_onnx")
class ExportToONNX(BaseTransform):
"""Transform that exports a PyTorch GraphModule to ONNX format for deployment.
This transform is responsible for:
1. Exporting the model graph to ONNX format with dynamic shapes support
2. Generating configuration files (config.json) for the exported model
3. Saving tokenizer files (tokenizer.json, vocab.json, etc.)
4. Processing and exporting chat templates
The exported ONNX model includes custom ops from the auto_deploy domain. These custom ops include:
- torch_onnx_attention_plugin: Fused RoPE + Attention for efficient inference (exported as EdgeLLM's custom op)
- torch_onnx_gather_nd: N-dimensional gather operation (exported as onnxscript.opset20.GatherND)
Note:
This transform does NOT modify the input graph. It only exports the graph
to external files and returns the original graph unchanged.
Attributes:
config: ExportToONNXConfig containing output directory, and is_eagle_base flag.
"""
config: ExportToONNXConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
"""Return the configuration class for this transform."""
return ExportToONNXConfig
def _export_onnx_model(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
_factory: ModelFactory,
_shared_config: SharedConfig,
) -> bool:
"""Export the GraphModule to ONNX format using torch.onnx.export with dynamo.
This method handles the core ONNX export logic:
1. Extracts input placeholders and their metadata from the graph
2. Configures dynamic shapes for batch size, sequence length, and KV cache
3. Sets up custom translation table for auto_deploy custom ops
4. Exports the model using torch.onnx.export with dynamo=True
Args:
gm: The PyTorch FX GraphModule to export.
cm: CachedSequenceInterface containing max batch size and sequence length info.
_factory: ModelFactory instance (unused in this method).
_shared_config: SharedConfig instance (unused in this method).
Returns:
bool: True if export was successful.
"""
# Extract input placeholders from graph to build kwargs for export
args = []
kwargs = {}
placeholders = gm.graph.find_nodes(op="placeholder")
for ph in placeholders:
kwargs[ph.name] = ph.meta["val"]
args = tuple(args)
ad_logger.info("Placeholders args:")
for i, e in enumerate(args):
ad_logger.info(f" {i}: {placeholders[i].name:20} {e}")
ad_logger.info("Placeholders kwargs:")
for k, v in kwargs.items():
ad_logger.info(f" {k}: {v}")
# Build output path
output_path = self.config.output_dir / "model.onnx"
# Build dynamic_shapes for dynamo export
# For dynamo, we need to specify dynamic dimensions for each input tensor
dynamic_shapes = {}
dynamic_shapes["input_ids"] = {
0: Dim("batch_size", min=0, max=cm.info.max_batch_size),
1: Dim("seq_len", min=0, max=cm.info.max_seq_len),
}
# Add dynamic shapes for context_lengths and rope_rotary_cos_sin
dynamic_shapes["context_lengths"] = {
0: Dim("batch_size", min=0, max=cm.info.max_batch_size)
}
dynamic_shapes["rope_rotary_cos_sin"] = {
0: Dim("rope_batch_size", min=1, max=16),
1: Dim("max_position_embeddings", min=1, max=4096),
}
dynamic_shapes["kvcache_start_index"] = {
0: Dim("kv_cache_start_batch_size", min=0, max=cm.info.max_batch_size)
}
# Add dynamic shapes for past_key_values
num_layers = len(
gm.graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
)
)
for i in range(num_layers):
dynamic_shapes[f"past_key_values_{i}"] = {
0: Dim("batch_size", min=0, max=cm.info.max_batch_size),
3: Dim("past_len", min=1, max=4096),
}
dynamic_shapes["last_token_ids"] = {
0: Dim("batch_size", min=0, max=cm.info.max_batch_size),
1: Dim("num_selected_tokens", min=1, max=cm.info.max_seq_len),
}
# Create custom translation table for ONNX export
# Map torch custom ops to their corresponding onnxscript translation functions
custom_translation_table = {
# Before fuse rope attention
# NOTE (yoco): These two ops are normally fused into the AttentionPlugin operation
# by the fuse_rope_attention transform. However, if TensorRT-LLM changes and the
# fusion is not applied, we may still want to export the .onnx for debugging and
# inspect the graph, so keep their translations here.
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin.default: _translate_rope_op,
torch.ops.auto_deploy.torch_attention.default: _translate_torch_attention_op,
# Before and after fuse rope attention
torch.ops.auto_deploy.torch_linear_simple.default: _translate_simple_linear_op,
# After fuse rope attention
torch.ops.auto_deploy.torch_onnx_attention_plugin.default: _translate_rope_attention_op,
torch.ops.auto_deploy.torch_onnx_gather_nd.default: _translate_gather_nd_op,
}
# Prepare output names
output_names = []
output_node = gm.graph.find_nodes(op="output")[0]
outputs = output_node.args[0]
for output in outputs:
output_names.append(output.name)
output_names[0] = "logits"
# Register ONNX custom ops
_onnx_schemas.register_onnx_schemas()
# Export the graph module to ONNX using dynamo (more advanced tracer)
ad_logger.info(f"Exporting GraphModule to ONNX with dynamo: {output_path}")
torch.onnx.export(
gm,
tuple(args),
output_path,
opset_version=20,
kwargs=kwargs,
dynamo=True,
dynamic_shapes=dynamic_shapes,
report=False,
output_names=output_names,
custom_translation_table=custom_translation_table,
)
ad_logger.info(f"Successfully exported ONNX model to {output_path}")
return True
def _export_config_json(
self, factory: ModelFactory, output_dir: str, is_eagle_base: bool
) -> None:
"""Export model configuration to config.json.
Generates a configuration file containing model architecture parameters
such as hidden size, number of layers, attention heads, etc.
Args:
factory: The ModelFactory containing model configuration.
output_dir: Directory path where config.json will be saved.
is_eagle_base: If True, exports as "eagle3_base" model type,
otherwise exports as "llm" model type.
"""
model_type = "eagle3_base" if is_eagle_base else "llm"
assert isinstance(factory, hf.AutoModelFactory)
model_config, _ = factory._get_model_config()
model_config = export_llm_config(model_config, model_type)
# Add reduced_vocab_size to config if vocabulary reduction is used
reduced_vocab_size = None # TODO: Implement this
if reduced_vocab_size is not None:
model_config["reduced_vocab_size"] = reduced_vocab_size
ad_logger.info(f"Added reduced_vocab_size={reduced_vocab_size} to config")
config_path = os.path.join(output_dir, "config.json")
with open(config_path, "w") as f:
json.dump(model_config, f, indent=2)
ad_logger.info(f"Model configuration saved to {config_path}")
def _export_json_files(
self,
gm: GraphModule,
_cm: CachedSequenceInterface,
factory: ModelFactory,
_shared_config: SharedConfig,
) -> None:
"""Export all JSON configuration and tokenizer files required for deployment.
This method orchestrates the export of:
1. config.json - Model architecture configuration
2. Tokenizer files - added_tokens.json, special_tokens_map.json,
tokenizer.json, tokenizer_config.json, vocab.json
3. processed_chat_template.json - Processed chat template for inference
Args:
gm: The GraphModule containing model configuration.
_cm: CachedSequenceInterface (unused, kept for interface consistency).
factory: ModelFactory used to initialize tokenizer and get model directory.
_shared_config: SharedConfig (unused, kept for interface consistency).
"""
# Export model configuration (architecture params, layer counts, etc.)
is_eagle_base = self.config.is_eagle_base
output_dir = self.config.output_dir
self._export_config_json(factory, output_dir, is_eagle_base)
# Export tokenizer files for text processing during inference
# Includes: added_tokens.json, special_tokens_map.json, tokenizer.json,
# tokenizer_config.json, vocab.json
tokenizer = factory.init_tokenizer()
tokenizer.save_pretrained(output_dir)
# Export processed chat template for conversational inference
model_dir = factory.model
process_chat_template(model_dir, output_dir)
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
"""Apply the ONNX export transform to the graph module.
This is the main entry point that orchestrates the export process:
1. Creates the output directory if it doesn't exist
2. Exports all JSON configuration files (config, tokenizer, chat template)
3. Exports the ONNX model with dynamic shapes and custom ops
Note:
Unlike other transforms, this does NOT modify the graph.
It only performs side effects (writing files to disk).
Args:
gm: The PyTorch FX GraphModule to export.
cm: CachedSequenceInterface containing runtime constraints.
factory: ModelFactory for tokenizer initialization.
shared_config: SharedConfig for transform coordination.
Returns:
Tuple containing:
- gm: The original GraphModule (unchanged)
- info: TransformInfo with export status (is_clean=True since graph is unmodified)
"""
# Ensure output directory exists before writing any files
if not self.config.output_dir.exists():
self.config.output_dir.mkdir(parents=True, exist_ok=True)
# Step 1: Export all auxiliary JSON files (config, tokenizer, chat template)
self._export_json_files(gm, cm, factory, shared_config)
# Step 2: Export the ONNX model with dynamic shapes
success = self._export_onnx_model(gm, cm, factory, shared_config)
# Return original graph unchanged with export status info
# This transform is "clean" because it doesn't modify the graph structure
info = TransformInfo(
skipped=not success,
num_matches=1 if success else 0,
is_clean=True, # Graph is not modified
has_valid_shapes=True, # Shape validity is preserved
)
return gm, info
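# Illustrative post-export check (not part of the original file): verify that the
# custom ops ended up in the expected domains of the exported model.
#
#   import onnx
#   m = onnx.load("model.onnx")
#   print(sorted({(n.domain, n.op_type) for n in m.graph.node
#                 if n.domain in ("trt", "auto_deploy")}))
#   # e.g. [("trt", "AttentionPlugin"), ...]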

View File

@@ -0,0 +1,551 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from typing import List, Tuple
import torch
import transformers
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import add_graph_input, add_graph_output, remove_graph_input
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_op_args, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
class MatchResult:
"""Container for matched RoPE + attention pattern nodes.
This class stores all the relevant nodes and metadata from a successfully
matched RoPE (Rotary Position Embedding) + attention pattern in the graph.
It is used to facilitate the pattern replacement during the fusion transform.
Attributes:
q: The original query tensor node before view/reshape.
k: The original key tensor node before view/reshape.
v: The original value tensor node before view/reshape.
cos: The cosine embedding node for RoPE.
sin: The sine embedding node for RoPE.
attn_node: The attention operation node to be replaced.
rope_node: The RoPE operation node to be fused.
head_dim: Dimension of each attention head.
num_q_heads: Number of query attention heads.
num_kv_heads: Number of key-value attention heads (for GQA/MQA).
"""
def __init__(
self,
q: Node,
k: Node,
v: Node,
cos: Node,
sin: Node,
attn_node: Node,
rope_node: Node,
head_dim: int,
num_q_heads: int,
num_kv_heads: int,
):
"""Initialize MatchResult with matched pattern nodes and metadata.
Args:
q: Query tensor node.
k: Key tensor node.
v: Value tensor node.
cos: RoPE cosine node.
sin: RoPE sine node.
attn_node: Attention operation node.
rope_node: RoPE operation node.
head_dim: Attention head dimension.
num_q_heads: Number of query heads.
num_kv_heads: Number of key-value heads.
"""
self.q = q
self.k = k
self.v = v
self.cos = cos
self.sin = sin
self.attn_node = attn_node
self.rope_node = rope_node
self.head_dim = head_dim
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
def __repr__(self):
"""Return string representation of the match result."""
return (
f"MatchResult(q={self.q.name}, k={self.k.name}, v={self.v.name}, "
f"cos={self.cos.name}, sin={self.sin.name}, head_dim={self.head_dim}, "
f"num_q_heads={self.num_q_heads}, num_kv_heads={self.num_kv_heads})"
)
@TransformRegistry.register("fuse_rope_attention")
class FuseRopeAttention(BaseTransform):
"""Transform that fuses RoPE and attention operations into a single AttentionPlugin.
This transform identifies patterns in the graph where RoPE (Rotary Position
Embedding) is applied to query and key tensors followed by scaled dot-product
attention, and replaces them with a fused AttentionPlugin operation.
The fusion provides several benefits:
- Reduces memory bandwidth by eliminating intermediate tensors
- Enables KV-cache integration for efficient autoregressive generation
- Allows backend-specific optimizations (e.g., TensorRT, EdgeLLM)
Pattern matched (backwards from attention):
1. torch_attention(rope_q, rope_k, view_v, attn_mask, ...)
2. rope_q = rope[1], rope_k = rope[0]
3. rope = torch_rope_with_explicit_cos_sin(cont_q, cont_k, cos, sin, 2)
4. cont_q = contiguous(view_q), cont_k = contiguous(view_k)
5. view_q = view(q, ...), view_k = view(k, ...), view_v = view(v, ...)
The transform also:
- Adds new graph inputs: context_lengths, rope_rotary_cos_sin, kvcache_start_index
- Adds past_key_values inputs for each attention layer
- Adds present_key_values outputs for each attention layer
- Removes the position_ids placeholder (no longer needed after fusion)
"""
def _get_batch_size_and_max_seq_len(self, gm: GraphModule) -> Tuple[int, int, Node, Node]:
"""Get batch size and max sequence length from the graph.
Args:
gm: The GraphModule to get batch size and max sequence length from.
Returns:
Tuple of (batch_size_dim, max_seq_len_dim, batch_size_sym_node, max_seq_len_sym_node).
Note:
For clarity, we return the symbolic nodes as well, which can be
used in operations like view or reshape.
"""
graph = gm.graph
input_ids_node = graph.find_nodes(op="placeholder", target="input_ids")[0]
position_ids_node = graph.find_nodes(op="placeholder", target="position_ids")[0]
input_ids_meta = input_ids_node.meta.get("val")
batch_size_dim = input_ids_meta.size(0)
max_seq_len_dim = input_ids_meta.size(1)
# We scan all the sym_size.int nodes to find the symbolic nodes for
# batch size and max sequence length. The symbolic nodes have two sources:
# 1. The input_ids placeholder.
# 2. The position_ids placeholder.
# Both have shape (batch_size, max_seq_len).
batch_size_sym_node = None
max_seq_len_sym_node = None
sym_ints = graph.find_nodes(op="call_function", target=torch.ops.aten.sym_size.int)
for sym_int in sym_ints:
if sym_int.args[0] != input_ids_node and sym_int.args[0] != position_ids_node:
continue
if sym_int.args[1] == 0:
batch_size_sym_node = sym_int
elif sym_int.args[1] == 1:
max_seq_len_sym_node = sym_int
assert batch_size_sym_node is not None and max_seq_len_sym_node is not None
return batch_size_dim, max_seq_len_dim, batch_size_sym_node, max_seq_len_sym_node
def _get_config_head_dim(self, model_config: transformers.PretrainedConfig) -> int:
"""Get head dimension from model config."""
if hasattr(model_config, "head_dim"):
return model_config.head_dim
else:
return model_config.hidden_size // model_config.num_attention_heads
def _match_rope_attention_pattern(
self, gm: GraphModule, model_config: transformers.PretrainedConfig
) -> List[MatchResult]:
"""Match RoPE + attention patterns in the computation graph.
Traverses the graph backwards from attention nodes to identify the
complete pattern of RoPE application followed by attention computation.
Pattern structure (backwards from attention):
1. torch_attention(rope_q, rope_k, bind_v, attn_mask, ...)
2. rope_q, rope_k = rope[1], rope[0]
3. rope = torch_rope_with_explicit_cos_sin(bind_q, bind_k, cos, sin, 2)
Args:
gm: The GraphModule to search for patterns.
Returns:
List of MatchResult objects, each containing the matched nodes
and extracted metadata (head_dim, num_q_heads, num_kv_heads).
"""
matches = []
graph = gm.graph
head_dim = self._get_config_head_dim(model_config)
# Iterate through all nodes to find attention ops
for attn_node in graph.nodes:
if not is_op(attn_node, torch.ops.auto_deploy.torch_attention):
continue
if attn_node.args[10] != "bsnd":
ad_logger.error(
f" Skipping: attention layout is not bsnd: {attn_node.kwargs.get('layout', None)}"
)
continue
ad_logger.debug(f"Found attention node: {attn_node.name}")
# Extract attention inputs: (rope_q, rope_k, v, attn_mask, ...)
if len(attn_node.args) < 4:
ad_logger.error(f" Skipping: insufficient args ({len(attn_node.args)})")
continue
rope_q_node, rope_k_node, bind_v = extract_op_args(attn_node, "query", "key", "value")
# Step 1: Match rope_q and rope_k as getitem[1] and getitem[0] from rope output
if not (is_op(rope_q_node, operator.getitem) and is_op(rope_k_node, operator.getitem)):
ad_logger.error(" Skipping: rope_q or rope_k not getitem")
continue
# Verify they come from the same rope node
assert rope_q_node.target == operator.getitem
assert rope_k_node.target == operator.getitem
rope_node_from_q, rope_q_idx = rope_q_node.args
rope_node_from_k, rope_k_idx = rope_k_node.args
if rope_node_from_q != rope_node_from_k:
ad_logger.error(" Skipping: rope_q and rope_k come from different rope nodes")
continue
rope_node = rope_node_from_q
# Verify getitem indices: rope[0] = rope_k, rope[1] = rope_q
if rope_k_idx != 0 or rope_q_idx != 1:
ad_logger.error(
f" Skipping: incorrect getitem indices (k={rope_k_idx}, q={rope_q_idx})"
)
continue
# Step 2: Match the rope node
if not is_op(rope_node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin):
ad_logger.error(" Skipping: not a rope node")
continue
# Extract rope inputs: (cont_q, cont_k, cos, sin, 2)
if len(rope_node.args) < 5:
ad_logger.error(f" Skipping: rope has insufficient args ({len(rope_node.args)})")
continue
bind_k = rope_node.args[0]
bind_q = rope_node.args[1]
cos_node = rope_node.args[2]
sin_node = rope_node.args[3]
num_q_heads = model_config.num_attention_heads
num_kv_heads = model_config.num_key_value_heads
# Successfully matched the pattern!
match = MatchResult(
q=bind_q,
k=bind_k,
v=bind_v,
cos=cos_node,
sin=sin_node,
attn_node=attn_node,
rope_node=rope_node,
head_dim=head_dim,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
)
matches.append(match)
ad_logger.debug(f" ✓ Matched pattern: {match}")
return matches
def _add_global_placeholders(self, gm: GraphModule, factory: ModelFactory) -> Tuple[Node, Node, Node]:
"""Add global input placeholders required by the fused AttentionPlugin.
Creates three new graph inputs that are shared across all attention layers:
- context_lengths: Actual sequence length for each batch element
- rope_rotary_cos_sin: Precomputed RoPE embeddings
- kvcache_start_index: Starting position in KV-cache for each batch
These inputs enable dynamic sequence handling and efficient KV-cache
management during autoregressive generation.
Args:
gm: GraphModule to add placeholders to.
factory: ModelFactory used to look up the model config (e.g., head_dim).
Returns:
Tuple of (context_lengths_node, rope_rotary_cos_sin_node,
kvcache_start_index_node).
"""
graph = gm.graph
# Find token_ids placeholder to get batch_size symbolic dimension
token_ids_node = None
for node in graph.nodes:
if node.op == "placeholder" and "token" in node.name.lower():
token_ids_node = node
break
if token_ids_node is None:
# Fallback: use first placeholder
token_ids_node = graph.find_nodes(op="placeholder", sort=True)[0]
batch_size_dim, max_seq_len_dim, _, _ = self._get_batch_size_and_max_seq_len(gm)
ad_logger.debug(f"Extracted batch_size={batch_size_dim}, max_seq_len={max_seq_len_dim}")
# 1. Add context_lengths placeholder: int32[batch_size]
context_lengths_example = torch.zeros(batch_size_dim, dtype=torch.int32, device="meta")
context_lengths_node = add_graph_input(
gm, name="context_lengths", val=context_lengths_example
)
ad_logger.debug(f"Added context_lengths placeholder: {context_lengths_node.name}")
# 2. Add rope_rotary_cos_sin placeholder: float32[batch_size, max_seq_len, head_dim]
# Create with concrete example tensor
model_config, _ = factory._get_model_config()
head_dim = self._get_config_head_dim(model_config)
rope_example = torch.zeros(
batch_size_dim, max_seq_len_dim, head_dim, dtype=torch.float32, device="meta"
)
rope_rotary_cos_sin_node = add_graph_input(gm, name="rope_rotary_cos_sin", val=rope_example)
ad_logger.debug(f"Added rope_rotary_cos_sin placeholder: {rope_rotary_cos_sin_node.name}")
# 3. Add kvcache_start_index placeholder: int32[batch_size]
kvcache_start_index_example = torch.zeros(batch_size_dim, dtype=torch.int32, device="meta")
kvcache_start_index_node = add_graph_input(
gm, name="kvcache_start_index", val=kvcache_start_index_example
)
ad_logger.debug(f"Added kvcache_start_index placeholder: {kvcache_start_index_node.name}")
return context_lengths_node, rope_rotary_cos_sin_node, kvcache_start_index_node
def _perform_replacement(
self,
gm: GraphModule,
cm: "CachedSequenceInterface",
matches: List[MatchResult],
context_lengths_node: Node,
rope_rotary_cos_sin_node: Node,
kvcache_start_index_node: Node,
) -> int:
"""Replace matched RoPE + attention patterns with fused AttentionPlugin.
For each matched pattern, this method:
1. Creates a past_key_values input placeholder for KV-cache
2. Reshapes Q, K, V to (batch_size, seq_len, -1)
3. Concatenates reshaped Q, K, V tensors into a single QKV tensor
4. Inserts the fused AttentionPlugin operation
5. Creates getitem nodes to extract attention output and present KV-cache
6. Replaces the original attention node with the fused output
7. Adds present_key_values to graph outputs
Args:
gm: The GraphModule being transformed.
cm: CachedSequenceInterface for shape management.
matches: List of matched patterns to replace.
context_lengths_node: Shared context lengths input node.
rope_rotary_cos_sin_node: Shared RoPE embeddings input node.
kvcache_start_index_node: Shared KV-cache index input node.
Returns:
Number of patterns successfully replaced.
"""
graph = gm.graph
past_len = 4096  # Arbitrary placeholder; this value is replaced by a symbolic dimension at export time
# Get batch size & max sequence length
batch_size_dim, _, batch_size_sym_node, max_seq_len_sym_node = (
self._get_batch_size_and_max_seq_len(gm)
)
# Process each match
for match_id, match in enumerate(matches):
ad_logger.debug(f"Processing match {match_id}: {match}")
# 1. Create past_key_values_<id> placeholder
# Shape: float16[batch_size, 2, num_kv_heads, past_len, head_dim]
past_key_values_example = torch.zeros(
batch_size_dim,
2,
match.num_kv_heads,
past_len,
match.head_dim,
dtype=torch.float16,
device="meta",
)
past_key_values_node = add_graph_input(
gm, name=f"past_key_values_{match_id}", val=past_key_values_example
)
ad_logger.debug(f"Added past_key_values_{match_id} placeholder")
# 2. Reshape Q, K, V to (batch_size, seq_len, -1)
with graph.inserting_before(match.attn_node):
q_node = graph.call_function(
torch.ops.aten.view.default,
args=(match.q, (batch_size_sym_node, max_seq_len_sym_node, -1)),
)
k_node = graph.call_function(
torch.ops.aten.view.default,
args=(match.k, (batch_size_sym_node, max_seq_len_sym_node, -1)),
)
v_node = graph.call_function(
torch.ops.aten.view.default,
args=(match.v, (batch_size_sym_node, max_seq_len_sym_node, -1)),
)
ad_logger.debug(
f"Reshaped Q, K, V to (batch_size, seq_len, -1): {q_node.name}, {k_node.name}, {v_node.name}"
)
# 3. Concatenate the reshaped Q, K, V along the last dimension into a single QKV tensor
with graph.inserting_before(match.attn_node):
qkv_node = graph.call_function(
torch.ops.aten.cat.default, args=([q_node, k_node, v_node], -1)
)
ad_logger.debug(f"Created qkv concat node: {qkv_node.name}")
# 4. Create AttentionPlugin node
enable_tree_attention = 0
head_dim = match.head_dim
num_kv_heads = match.num_kv_heads
num_q_heads = match.num_q_heads
with graph.inserting_before(match.attn_node):
cast_node = graph.call_function(
torch.ops.aten.to.dtype, args=(qkv_node, torch.float16)
)
rope_attn_node = graph.call_function(
torch.ops.auto_deploy.torch_onnx_attention_plugin.default,
args=(
cast_node,
past_key_values_node,
context_lengths_node,
rope_rotary_cos_sin_node,
kvcache_start_index_node,
enable_tree_attention,
head_dim,
num_kv_heads,
num_q_heads,
),
)
ad_logger.debug(f"Created AttentionPlugin node: {rope_attn_node.name}")
# 5. Create getitem nodes for rope_attn outputs
with graph.inserting_after(rope_attn_node):
rope_attn_output = graph.call_function(operator.getitem, args=(rope_attn_node, 0))
rope_attn_present_kv = graph.call_function(
operator.getitem,
args=(rope_attn_node, 1),
name=f"present_key_values_{match_id}",
)
ad_logger.debug(
f"Created getitem nodes: {rope_attn_output.name}, {rope_attn_present_kv.name}"
)
# 6. Replace all uses of attn_node with rope_attn_output
match.attn_node.replace_all_uses_with(rope_attn_output)
ad_logger.debug(f"Replaced {match.attn_node.name} with {rope_attn_output.name}")
# 7. Add present_key_values_<id> to graph outputs using add_graph_output
add_graph_output(gm, rope_attn_present_kv, f"present_key_values_{match_id}")
ad_logger.debug(f"Added present_key_values_{match_id} to graph outputs")
# 8. Eliminate dead code (remove old pattern nodes)
graph.eliminate_dead_code()
ad_logger.debug("Eliminated dead code")
return len(matches)
def _remove_position_ids_placeholder(self, gm: GraphModule) -> str:
"""Remove the position_ids placeholder from the graph.
After fusing RoPE into AttentionPlugin, position_ids is no longer needed
as an input since positional information is now provided through
rope_rotary_cos_sin and kvcache_start_index.
This method also updates any sym_size operations that were reading from
position_ids to use input_ids instead.
Args:
gm: The GraphModule to modify.
Returns:
Name of the removed position_ids node.
"""
graph = gm.graph
position_ids_node = graph.find_nodes(op="placeholder", target="position_ids")[0]
input_ids_node = graph.find_nodes(op="placeholder", target="input_ids")[0]
# Redirect sym_size reads from the position_ids placeholder to the input_ids placeholder
sym_size_int_nodes = graph.find_nodes(
op="call_function", target=torch.ops.aten.sym_size.int
)
for node in sym_size_int_nodes:
if node.args[0] == position_ids_node:
new_args = (input_ids_node, *node.args[1:])
node.args = new_args
return remove_graph_input(gm, position_ids_node)
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
ad_logger.debug("Fusing rope attention")
# Perform pattern matching
model_config, _ = factory._get_model_config()
matches = self._match_rope_attention_pattern(gm, model_config)
cnt = len(matches)
ad_logger.info(f"Matched {cnt} rope attention patterns")
if cnt == 0:
ad_logger.info("No patterns matched, skipping replacement")
info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
return gm, info
# Add global placeholders
context_lengths_node, rope_rotary_cos_sin_node, kvcache_start_index_node = (
self._add_global_placeholders(gm, factory)
)
# Perform replacement for all matches
num_replaced = self._perform_replacement(
gm,
cm,
matches,
context_lengths_node,
rope_rotary_cos_sin_node,
kvcache_start_index_node,
)
# Remove position_ids placeholder
removed_node_name = self._remove_position_ids_placeholder(gm)
ad_logger.debug(f"Removed position_ids placeholder: {removed_node_name}")
info = TransformInfo(
skipped=False, num_matches=num_replaced, is_clean=True, has_valid_shapes=True
)
ad_logger.info(f"Fused {num_replaced} rope attention patterns")
return gm, info
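# Shape sketch (illustrative, not part of the original file): after fusion the plugin
# consumes a packed QKV tensor of shape
#   [batch, seq_len, (num_q_heads + 2 * num_kv_heads) * head_dim]
# e.g. with num_q_heads=14, num_kv_heads=2, head_dim=64 (Qwen2.5-0.5B-like sizes):
#   qkv_hidden = (14 + 2 * 2) * 64  # = 1152
# while the past/present KV caches keep the [batch, 2, num_kv_heads, past_len, head_dim] layout.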

View File

@@ -0,0 +1,162 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import add_graph_input
from ...utils.logger import ad_logger
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@TransformRegistry.register("gather_last_token_ids")
class GatherLastTokenIds(BaseTransform):
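"""Transform that gathers selected token positions before the logits projection.
Adds a last_token_ids graph input and inserts a GatherND on the hidden states
feeding the final linear layer, so only the selected tokens are projected to
vocabulary logits (reducing LM-head compute during generation).
"""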
def _find_last_linear_simple(self, gm: GraphModule) -> Node:
"""Find the final torch_linear_simple node that produces logits.
This is the last matmul operation in the graph
that produces the vocabulary logits.
"""
graph = gm.graph
# Look for the output node and trace back to find MatMul
output_node = graph.find_nodes(op="output")[0]
# Look for the last torch_linear_simple node
linear_simple_nodes = graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_linear_simple.default, sort=True
)
assert len(linear_simple_nodes) > 0, "No linear simple nodes found"
last_linear_simple_node = linear_simple_nodes[-1]
# Verify that the last linear simple node is the one producing the logits
if last_linear_simple_node not in output_node.args[0]:
ad_logger.warning(
f"Last linear simple node {last_linear_simple_node.name} is not the one producing the logits"
)
return None
return last_linear_simple_node
def _add_last_token_ids_input(self, gm: GraphModule) -> Node:
"""Add last_token_ids as a graph input.
Shape: int64[batch_size, num_selected_tokens]
"""
graph = gm.graph
# Find token_ids or input_ids placeholder to get batch_size dimension
token_ids_node = None
for node in graph.nodes:
if node.op == "placeholder" and (
"token" in node.name.lower() or "input_ids" in node.name
):
token_ids_node = node
break
if token_ids_node is None:
# Fallback: use first placeholder
token_ids_node = graph.find_nodes(op="placeholder", sort=True)[0]
ad_logger.info(f"Using {token_ids_node.name} to extract batch_size dimension")
# Get symbolic batch_size dimension
token_ids_meta = token_ids_node.meta.get("val")
assert token_ids_meta is not None, "token_ids_meta is None"
batch_size_dim = token_ids_meta.size(0)
ad_logger.info(f"Extracted batch_size={batch_size_dim}")
# Add last_token_ids placeholder: int64[batch_size]
num_selected_tokens = 2
input_shape = (batch_size_dim, num_selected_tokens)
last_token_ids_example = torch.zeros(input_shape, dtype=torch.int64, device="meta")
last_token_ids_node = add_graph_input(gm, name="last_token_ids", val=last_token_ids_example)
ad_logger.info(f"Added last_token_ids placeholder: {last_token_ids_node.name}")
return last_token_ids_node
def _insert_gather_nd(
self,
gm: GraphModule,
linear_simple: Node,
last_token_ids_node: Node,
) -> bool:
"""Insert GatherND operation before MatMul.
The pattern to create:
1. Create GatherND with the linear input and last_token_ids
2. Replace MatMul input with GatherND output
"""
graph = gm.graph
# Get the input to linear_simple (should be from the previous layer)
linear_input = linear_simple.args[0]
ad_logger.info(
f"Linear input: {linear_input.name if hasattr(linear_input, 'name') else linear_input}"
)
# 1. Insert GatherND operation before linear_simple
# GatherND takes the hidden states (linear_input) and gathers based on last_token_ids
with graph.inserting_before(linear_simple):
unsqueeze_node = graph.call_function(
torch.ops.aten.unsqueeze.default, args=(last_token_ids_node, -1)
)
gather_nd_node = graph.call_function(
torch.ops.auto_deploy.torch_onnx_gather_nd.default,
args=(linear_input, unsqueeze_node, 1), # Use linear_input, not linear_simple!
)
ad_logger.info(f"Created GatherND node: {gather_nd_node.name}")
# 2. Replace linear_simple's first input with GatherND output
linear_simple.replace_input_with(linear_input, gather_nd_node)
ad_logger.info("Replaced Linear input with GatherND output")
return True
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
ad_logger.info("Adding last_token_ids gather operation")
# Step 1: Find last linear simple node
linear_simple_node = self._find_last_linear_simple(gm)
if linear_simple_node is None:
ad_logger.info("Could not find last linear simple node, skipping")
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
# Step 2: Add last_token_ids input
last_token_ids_node = self._add_last_token_ids_input(gm)
# Step 3: Insert GatherND operation
success = self._insert_gather_nd(gm, linear_simple_node, last_token_ids_node)
if not success:
ad_logger.info("Failed to insert GatherND operation")
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)

View File

@ -0,0 +1,165 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import torch
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@TransformRegistry.register("short_reshape_attention_output")
class ShortReshapeAttentionOutput(BaseTransform):
"""Transform that simplifies reshape operations after attention outputs.
This transform optimizes reshape nodes that follow AttentionPlugin outputs
by replacing fixed shape dimensions with dynamic symbolic dimensions derived
from the input tensor. This ensures the reshape operation correctly handles
dynamic batch sizes and sequence lengths.
The transform identifies patterns where:
1. A reshape node follows an AttentionPlugin output
2. The reshape's input can be traced back to a linear projection
It then replaces the reshape with a new one that uses symbolic dimensions
from the linear projection's input, resulting in the shape [batch, seq, -1].
"""
def _lookup_ascending_node(self, node: Node, target, max_depth: int = 3) -> Optional[Node]:
"""Recursively search for an ancestor node with a specific target operation.
Traverses the graph backwards through node arguments to find a node
that matches the specified target operation type.
Args:
node: Starting node for the backward search.
target: Target operation to search for (e.g., torch.ops.aten.reshape.default).
max_depth: Maximum number of levels to traverse (default: 3).
Returns:
The matching ancestor node if found, None otherwise.
"""
if max_depth == 0:
return None
if node.target == target:
return node
# Helper function to check a single node
def check_node(n):
if isinstance(n, Node):
result = self._lookup_ascending_node(n, target, max_depth - 1)
if result is not None:
return result
return None
# Check all arguments
for arg in node.args:
if isinstance(arg, (tuple, list)):
for item in arg:
result = check_node(item)
if result is not None:
return result
else:
result = check_node(arg)
if result is not None:
return result
return None
def _find_reshape_attention_output(self, gm: GraphModule) -> List[Tuple[Node, Node]]:
"""Find reshape nodes that follow AttentionPlugin outputs.
Searches for reshape operations that:
1. Have an AttentionPlugin in their input chain
2. Can be traced back to a torch_linear_simple operation
The linear operation is needed to extract the original input shape
for constructing the new reshape with symbolic dimensions.
Args:
gm: The GraphModule to search.
Returns:
List of (reshape_node, linear_node) tuples representing the
matched patterns.
"""
reshape_linear_pairs = []
reshape_nodes = gm.graph.find_nodes(
op="call_function", target=torch.ops.aten.reshape.default
)
for node in reshape_nodes:
# Looking for AttentionPlugin from the input of the reshape node
attention_plugin_node = self._lookup_ascending_node(
node.args[0], torch.ops.auto_deploy.torch_onnx_attention_plugin.default
)
if attention_plugin_node is None:
continue
# Looking for torch_linear_simple from the input of the AttentionPlugin node
linear_simple_node = self._lookup_ascending_node(
attention_plugin_node.args[0], torch.ops.auto_deploy.torch_linear_simple.default
)
if linear_simple_node is None:
continue
reshape_linear_pairs.append((node, linear_simple_node))
return reshape_linear_pairs
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
num_matches = 0
reshape_linear_pairs = self._find_reshape_attention_output(gm)
self._log_info(f"Found {len(reshape_linear_pairs)} reshape_linear_pairs")
graph = gm.graph
for reshape_node, linear_simple_node in reshape_linear_pairs:
# Get the input tensor of reshape node (keep it unchanged)
shape_src = linear_simple_node.args[0]
data_input = reshape_node.args[0]
# Extract first two dimensions from the input tensor
with graph.inserting_before(reshape_node):
dim0 = graph.call_function(torch.ops.aten.sym_size.int, args=(shape_src, 0))
dim1 = graph.call_function(torch.ops.aten.sym_size.int, args=(shape_src, 1))
# Create new shape using input tensor's first two dimensions
new_shape = [dim0, dim1, -1]
# Create a NEW reshape node instead of modifying the existing one
# This ensures the graph dependencies are correctly established
new_reshape_node = graph.call_function(
torch.ops.aten.reshape.default, args=(data_input, new_shape)
)
# Replace all uses of the old reshape node with the new one
reshape_node.replace_all_uses_with(new_reshape_node)
# Remove the old reshape node from the graph
graph.erase_node(reshape_node)
num_matches += 1
info = TransformInfo(
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
)
return gm, info

View File

@ -64,11 +64,11 @@ class InferenceOptimizer:
mod = nn.Module()
# iterate over all transforms sorted by stage in the config
for t_name, t_config in self.config.items():
for idx, (t_name, t_config) in enumerate(self.config.items()):
# instantiate transform
transform = TransformRegistry.get(t_name)(t_config)
# run transform
mod = transform(mod, cm, self.factory, self.shared_config)
mod = transform(mod, cm, self.factory, self.shared_config, idx)
############################################################################################
# RETURN OPTIMIZED MODEL

View File

@ -15,12 +15,26 @@ from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import Graph, GraphModule, Node
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from torch.utils._pytree import _LEAF_SPEC, TreeSpec
from .logger import ad_logger
from .node_utils import get_weight_tensor, is_op
_NoValType = type("_NoValType", (), {})
def _call_post_init_recursive(spec: TreeSpec) -> None:
"""Recursively call __post_init__ on TreeSpec starting from leaves.
This is needed to update internal cached values (like num_leaves) after
modifying children_specs.
"""
if hasattr(spec, "children_specs"):
for child_spec in spec.children_specs:
_call_post_init_recursive(child_spec)
spec.__post_init__()
_NO_VAL = _NoValType()
@ -367,12 +381,7 @@ def add_graph_input(
in_spec_for_args.context.append(name)
# update pytree info recursively with __post_init__ starting at leaves
def call_post_init(spec):
for child_spec in spec.children_specs:
call_post_init(child_spec)
spec.__post_init__()
call_post_init(in_spec)
_call_post_init_recursive(in_spec)
# set fake tensor information if all required information is available
fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
@ -527,3 +536,194 @@ def del_attr_by_name(obj, name):
for part in parts[:-1]:
obj = getattr(obj, part)
delattr(obj, parts[-1])
def add_graph_output(gm: GraphModule, output_node: Node, name: str) -> None:
"""Add a graph output to the given GraphModule.
This function appends a new output to the graph's output node and updates
the pytree_info metadata accordingly.
NOTE: function does NOT do any graph canonicalization. This is left to the user!
Args:
gm (GraphModule): The GraphModule to add the output to.
output_node (Node): The node to add as an output.
name (str): The name of the new output in the output dict.
Raises:
RuntimeError: If the graph has no output node or multiple output nodes.
Implementation Notes (Hacky Details):
1. **Handling single-output graphs (out_spec == _LEAF_SPEC)**:
When the graph originally has a single output, its out_spec is a
_LEAF_SPEC (a leaf node in the pytree). To add a new output, we must
first convert it to a container spec (tuple with children_specs).
2. **Forcing output type to dict**:
The original out_spec.type is typically a model-specific output class
(e.g., CausalLMOutputWithPast from transformers). When we add new
outputs with custom names (like "present_key_values_xx"), the original
class constructor will fail because these names are not valid data
members of that class.
PyTorch uses the out_spec.type's constructor to wrap the output:
result_obj = result_type(**output)
To avoid constructor failures, we forcibly change out_spec.type to
dict using object.__setattr__ (since TreeSpec is a frozen dataclass).
This ensures the output is returned as a plain dictionary instead of
the original model-specific class.
"""
# extract graph and output spec
graph: Graph = gm.graph
# find the output node
output_nodes = graph.find_nodes(op="output")
if len(output_nodes) != 1:
raise RuntimeError(f"Expected exactly one output node, found {len(output_nodes)}")
graph_output_node = output_nodes[0]
# get current outputs
current_outputs = graph_output_node.args[0]
if not isinstance(current_outputs, (tuple, list)):
# single output, convert to tuple
current_outputs = (current_outputs,)
# append new output
new_outputs = tuple(current_outputs) + (output_node,)
graph_output_node.args = (new_outputs,)
# update pytree info: append spec for the new output
# The out_spec structure mirrors the output structure
# For a tuple of outputs, out_spec should be a TreeSpec with children_specs
# And graph._codegen.pytree_info is a NamedTuple, so we have to use _replace to replace the out_spec
out_spec = graph._codegen.pytree_info.out_spec
assert out_spec is not None, "Graph must have an out_spec"
if out_spec == _LEAF_SPEC:
new_out_spec = TreeSpec(type=tuple, children_specs=[_LEAF_SPEC], context=["output"])
graph._codegen.pytree_info = graph._codegen.pytree_info._replace(out_spec=new_out_spec)
out_spec = graph._codegen.pytree_info.out_spec
if hasattr(out_spec, "children_specs"):
# out_spec is already a container (tuple/list), append to it
out_spec.children_specs.append(_LEAF_SPEC)
out_spec.context.append(name)
else:
# This shouldn't happen in normal cases, but handle it just in case
# If out_spec is a single leaf, we need to convert it to a tuple spec
raise NotImplementedError(
"Cannot add output to a graph with non-container out_spec. "
"This case is not currently supported."
)
# update pytree info recursively with __post_init__ starting at leaves
_call_post_init_recursive(out_spec)
# Change the type of the output spec to dict,
#
# NOTE(yocox): This affects how torch interprets the output of the graph.
# The original type is a model-specific output class,
# for example CausalLMOutputWithPast, depending on the model.
# To add new fields to the output, we need to change the type to dict.
# The type's constructor is used to create an object containing
# the result, for example:
#
# result_obj = result_type(**output)
#
# Because we added new outputs with names like "present_key_values_xx" that
# are not data members of CausalLMOutputWithPast, the constructor would fail,
# so we change the type to dict.
#
# However, out_spec.type belongs to a frozen dataclass,
# so we need object.__setattr__ to change it.
object.__setattr__(out_spec, "type", dict)
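# Minimal usage sketch (assumption: `kv_out` is a Node already produced in the graph;
# canonicalization is left to the caller, as noted in the docstring):
#   add_graph_output(gm, kv_out, name="present_key_values_0")
#   gm.graph.lint()
#   gm.recompile()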
def remove_graph_input(gm: GraphModule, input_node: Node) -> str:
"""Remove a graph input from the given GraphModule.
This is the inverse operation of add_graph_input(). It removes a placeholder node
and updates the pytree_info metadata accordingly. The function automatically detects
whether the node belongs to args or kwargs and updates the correct spec.
NOTE: function does NOT do any graph canonicalization. This is left to the user!
Args:
gm (GraphModule): The GraphModule to remove the input from.
input_node (Node): The placeholder node to remove.
Returns:
str: The name of the removed node.
Raises:
ValueError: If the provided Node is not a placeholder or not in the graph.
RuntimeError: If the input node is still being used by other nodes.
"""
graph: Graph = gm.graph
# validate that the node is a placeholder
if input_node.op != "placeholder":
raise ValueError(
f"Node '{input_node.name}' is not a placeholder node (op='{input_node.op}'). "
f"Only placeholder nodes can be removed as graph inputs."
)
# find all placeholder nodes and validate the node is in this graph
placeholder_nodes = graph.find_nodes(op="placeholder", sort=True)
if input_node not in placeholder_nodes:
raise ValueError(
f"Node '{input_node.name}' is not found in the graph's placeholder nodes. "
f"It may belong to a different graph or have already been removed."
)
# check that the node is not being used
if len(input_node.users) > 0:
user_names = [user.name for user in input_node.users.keys()]
raise RuntimeError(
f"Cannot remove input '{input_node.name}' "
f"because it is still being used by nodes: {user_names}. "
f"Please remove or replace all usages first."
)
# get pytree info
in_spec = graph._codegen.pytree_info.in_spec
args_spec = in_spec.children_specs[0] # tuple for args
kwargs_spec = in_spec.children_specs[1] # dict for kwargs
orig_args = graph._codegen.pytree_info.orig_args
# determine the global index of the node
global_idx = placeholder_nodes.index(input_node)
# determine number of args (kwargs come after args in placeholder order)
num_args = len(args_spec.children_specs)
# determine if this is an arg or kwarg, and compute relative index
if global_idx < num_args:
# it's an arg
relative_idx = global_idx
target_spec = args_spec
is_kwarg = False
else:
# it's a kwarg
relative_idx = global_idx - num_args
target_spec = kwargs_spec
is_kwarg = True
# save the node name before removing
removed_node_name = input_node.name
# remove the node from the graph
graph.erase_node(input_node)
# update pytree info: remove the corresponding spec and arg name
target_spec.children_specs.pop(relative_idx)
if is_kwarg:
target_spec.context.pop(relative_idx)
orig_args.pop(global_idx)
# update pytree info recursively with __post_init__ starting at leaves
_call_post_init_recursive(in_spec)
return removed_node_name
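# Minimal usage sketch (assumption: a no-longer-used `position_ids` placeholder exists):
#   placeholders = gm.graph.find_nodes(op="placeholder", sort=True)
#   pos_ids = next(n for n in placeholders if n.name == "position_ids")
#   remove_graph_input(gm, pos_ids)
#   gm.recompile()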

View File

@ -56,7 +56,7 @@ class TorchAttentionReference:
v_flat = v.reshape(1, batch_size * seq_len, -1)
# Call torch backend via custom op registry
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
q_flat,
k_flat,
v_flat,
@ -97,7 +97,7 @@ class TorchAttentionReference:
This function directly calls the torch backend implementation via custom op registry.
"""
return torch.ops.auto_deploy.torch_cached_attention_with_cache(
return torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
q,
k,
v,
@ -148,7 +148,7 @@ class TorchAttentionReference:
batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
# Call torch backend via custom op registry
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
q_flat,
k_flat,
v_flat,
@ -185,7 +185,7 @@ class TorchAttentionReference:
This demonstrates how to use the torch backend with additional features.
"""
return torch.ops.auto_deploy.torch_cached_attention_with_cache(
return torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
q,
k,
v,

View File

@ -0,0 +1,65 @@
import os
import onnx
import pytest
from _model_test_utils import get_small_model_config
from tensorrt_llm._torch.auto_deploy.export import export_onnx
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
@pytest.mark.parametrize(
"model, output_dir, num_attn_ops",
[
(
"Qwen/Qwen2.5-3B-Instruct",
"/tmp/test_ad_export_onnx_qwen2.5-3b",
36,
),
],
)
def test_ad_export_onnx(model: str, output_dir: str, num_attn_ops: int):
ad_config = LlmArgs(
model=get_small_model_config(model)["args"]["model"],
mode="export_edgellm_onnx",
max_batch_size=13,
max_seq_len=4,
)
ad_config.transforms["export_to_onnx"]["output_dir"] = output_dir
export_onnx(ad_config)
# check if the output directory exists
assert os.path.exists(output_dir)
# check if the output json files exist
assert os.path.exists(os.path.join(output_dir, "added_tokens.json"))
assert os.path.exists(os.path.join(output_dir, "config.json"))
assert os.path.exists(os.path.join(output_dir, "processed_chat_template.json"))
assert os.path.exists(os.path.join(output_dir, "special_tokens_map.json"))
assert os.path.exists(os.path.join(output_dir, "tokenizer.json"))
assert os.path.exists(os.path.join(output_dir, "tokenizer_config.json"))
assert os.path.exists(os.path.join(output_dir, "vocab.json"))
# check if the output onnx file exists
assert os.path.exists(os.path.join(output_dir, "model.onnx"))
# check that the ONNX file has the expected number of AttentionPlugin operators
onnx_model = onnx.load(os.path.join(output_dir, "model.onnx"))
attn_ops = [node for node in onnx_model.graph.node if node.op_type == "AttentionPlugin"]
assert len(attn_ops) == num_attn_ops
# check input and output names of the model
actual_inputs_names = {input.name for input in onnx_model.graph.input}
actual_outputs_names = {output.name for output in onnx_model.graph.output}
expect_inputs_names = {
"input_ids",
"context_lengths",
"rope_rotary_cos_sin",
"kvcache_start_index",
"last_token_ids",
}
expect_outputs_names = {"logits"}
expect_inputs_names.update({f"past_key_values_{i}" for i in range(num_attn_ops)})
expect_outputs_names.update({f"present_key_values_{i}" for i in range(num_attn_ops)})
assert actual_inputs_names == expect_inputs_names
assert actual_outputs_names == expect_outputs_names

View File

@ -0,0 +1,242 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for fuse_rope_attention transformation.
"""
from collections import namedtuple
import pytest
import torch
from torch.export import Dim
# Import modules to register custom ops (torch.ops.auto_deploy.*)
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention # noqa: F401
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_rope # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
torch.manual_seed(0)
def _get_cos_sin(x: torch.Tensor, head_dim):
"""Precompute cos and sin for RoPE."""
batch_size = x.size(0)
seq_len = x.size(1)
cos = x.new_zeros(batch_size, seq_len, head_dim)
sin = x.new_zeros(batch_size, seq_len, head_dim)
return cos, sin
class RopeAttentionModel(torch.nn.Module):
"""
Model that implements the rope + attention pattern that fuse_rope_attention expects.
Pattern:
1. q_proj, k_proj, v_proj
2. view to [batch, seq, num_heads, head_dim]
3. contiguous
4. torch_rope_with_explicit_cos_sin for q and k
5. torch_attention with layout="bsnd"
"""
def __init__(
self,
head_dim: int,
num_q_heads: int,
num_kv_heads: int,
max_seq_len: int = 128,
):
super().__init__()
hidden_size = head_dim * num_q_heads
self.head_dim = head_dim
self.hidden_size = hidden_size
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.max_seq_len = max_seq_len
# Linear projections
self.q_proj = torch.nn.Linear(hidden_size, num_q_heads * self.head_dim, bias=False)
self.k_proj = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
self.v_proj = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
self.o_proj = torch.nn.Linear(num_q_heads * self.head_dim, hidden_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor, # Required for export API compatibility
) -> torch.Tensor:
x = input_ids.unsqueeze(-1).expand(-1, -1, self.hidden_size).to(torch.float16)
batch_size, seq_len, _ = x.shape
# Generate q, k, v: [batch, seq, hidden_size] -> [batch, seq, num_heads * head_dim]
q = self.q_proj(x) # [batch, seq, num_q_heads * head_dim]
k = self.k_proj(x) # [batch, seq, num_kv_heads * head_dim]
v = self.v_proj(x) # [batch, seq, num_kv_heads * head_dim]
# View to [batch, seq, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_q_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# Apply contiguous (this is part of the pattern)
q = torch.ops.aten.contiguous(q)
k = torch.ops.aten.contiguous(k)
# Precompute cos and sin for RoPE
cos, sin = _get_cos_sin(x, self.head_dim)
# Apply RoPE with unsqueeze_dim=2 for BSND layout
# NOTE(yoco): This should be (q, k) according to the definition of torch_rope_with_explicit_cos_sin.
# However, the actual graph uses (k, q); this needs further investigation.
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(k, q, cos, sin, 2)
# Apply attention with layout="bsnd"
attn_output = torch.ops.auto_deploy.torch_attention(
k_rot,
q_rot,
v,
None, # attn_mask
0.0, # dropout_p
True, # is_causal
None, # scale
None, # sinks
None, # sliding_window
None, # logit_cap
"bsnd", # layout
)
# Reshape back and apply output projection
attn_output = attn_output.reshape(batch_size, seq_len, -1)
output = self.o_proj(attn_output)
return output
def _apply_rope_manual(self, q, k, cos, sin):
"""Manual rope application for fallback/comparison."""
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
cos = cos.unsqueeze(2) # unsqueeze_dim=2 for BSND
sin = sin.unsqueeze(2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def get_dynamic_shapes(self):
return [
{0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)},
{0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)},
]
def convert_placeholder_meta_to_cuda(gm: torch.fx.GraphModule):
"""Convert placeholder meta to cuda."""
for node in gm.graph.nodes:
if node.op == "placeholder":
node.meta["val"] = node.meta["val"].to("cuda")
def _run_test(
head_dim: int,
num_q_heads: int,
num_kv_heads: int,
batch_size: int,
seq_len: int,
):
"""Helper function to run the transformation test."""
model = RopeAttentionModel(
head_dim=head_dim,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
).to("cuda", torch.float16)
input_ids = torch.randint(0, 10, (batch_size, seq_len), device="cuda", dtype=torch.int32)
position_ids = torch.randint(0, 10, (batch_size, seq_len), device="cuda", dtype=torch.int32)
dynamic_shapes = model.get_dynamic_shapes()
# Export to graph module
args = (input_ids, position_ids)
gm = torch_export_to_gm(model, args=args, dynamic_shapes=dynamic_shapes, clone=True)
# Minimal model-config stub exposing num_attention_heads, hidden_size, and num_key_value_heads
class Factory:
def _get_model_config(self):
ModelConfig = namedtuple(
"ModelConfig", ["num_attention_heads", "hidden_size", "num_key_value_heads"]
)
return ModelConfig(
num_attention_heads=num_q_heads,
hidden_size=head_dim * num_q_heads,
num_key_value_heads=num_kv_heads,
), None
# Apply fuse_rope_attention transformation
optimizer = InferenceOptimizer(
Factory(),
{
"fuse_rope_attention": {
"stage": "pattern_matcher",
},
},
)
convert_placeholder_meta_to_cuda(gm)
gm_transformed = optimizer(None, gm)
gm_transformed.to("cuda")
fused_nodes = gm_transformed.graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
)
assert len(fused_nodes) == 1, f"Expected 1 AttentionPlugin node, got {len(fused_nodes)}"
input_nodes = gm_transformed.graph.find_nodes(op="placeholder")
assert len(input_nodes) == 5, f"Expected 5 input nodes, got {len(input_nodes)}"
input_nodes_targets = {node.target for node in input_nodes}
assert input_nodes_targets == {
"input_ids",
"context_lengths",
"rope_rotary_cos_sin",
"kvcache_start_index",
"past_key_values_0",
}
out_node = gm_transformed.graph.find_nodes(op="output")[0]
assert len(out_node.args[0]) == 2, f"Expected 2 output nodes, got {len(out_node.args[0])}"
@pytest.mark.parametrize(
"head_dim,num_q_heads,num_kv_heads,batch_size,seq_len",
[
# GQA (Grouped-Query Attention): num_q_heads > num_kv_heads
pytest.param(64, 14, 2, 2, 16, id="gqa_ratio_7"),
],
)
@torch.inference_mode()
def test_fuse_rope_attention(
head_dim: int,
num_q_heads: int,
num_kv_heads: int,
batch_size: int,
seq_len: int,
):
"""Test fuse_rope_attention transformation with various configurations."""
_run_test(
head_dim=head_dim,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
batch_size=batch_size,
seq_len=seq_len,
)