mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[None][feat] Export ONNX for DriveOS LLM (#10117)
Signed-off-by: yocox <yocox@nvidia.com>
This commit is contained in:
parent
f42a6cbae0
commit
4af47208d8
1 .gitignore vendored
@@ -1,5 +1,6 @@
__pycache__/
.vscode
.cursor
*.engine
*.engine.config
*.cache
@@ -68,6 +68,7 @@ init_ubuntu() {
gdb \
git-lfs \
clang \
graphviz \
lld \
llvm \
libclang-rt-dev \
93 docs/source/torch/auto_deploy/advanced/export_onnx.md Normal file
@@ -0,0 +1,93 @@
# Export ONNX for EdgeLLM

AutoDeploy provides a mode to export PyTorch/HuggingFace models to ONNX format specifically designed for EdgeLLM deployment. This mode performs graph transformations to fuse RoPE (Rotary Position Embedding) and attention operations into a single `AttentionPlugin` operation, then exports the optimized graph to ONNX.

## Overview

The `export_edgellm_onnx` mode differs from the standard AutoDeploy workflow in two key ways:

1. **Operation Fusion**: Fuses `torch_rope_with_explicit_cos_sin` and `torch_cached_attention_with_cache` into a single `AttentionPlugin` operation (see the sketch below)
2. **ONNX Export**: Outputs an ONNX model file instead of a TensorRT engine
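After an export, you can verify that the fusion took place by scanning the ONNX graph for the fused node. A minimal sketch, assuming a completed export in the current directory (the `trt` domain for custom operations is described under Notes below):

```python
import onnx

model = onnx.load("model.onnx")
fused = [n for n in model.graph.node if n.op_type == "AttentionPlugin"]
print(f"Found {len(fused)} AttentionPlugin node(s)")
for node in fused:
    # Custom operations are emitted in the `trt` domain.
    print(node.name, node.domain)
```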
## Quick Start

Use the `onnx_export_llm.py` script to export a model:

```bash
cd examples/auto_deploy
python onnx_export_llm.py --model "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
```

This exports the model to ONNX format in the current directory.

## Command Line Options

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--model` | str | Required | HuggingFace model name or path to a local checkpoint |
| `--device` | str | `cpu` | Device to use for export (`cpu` or `cuda`) |
| `--output_dir` | str | `.` | Directory to save the exported ONNX model |

## Examples

### Basic Export

Export a Qwen model with default settings:

```bash
python onnx_export_llm.py --model "Qwen/Qwen2.5-0.5B-Instruct"
```

### Custom Output Location

Export to a specific directory:

```bash
python onnx_export_llm.py \
    --model "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" \
    --output_dir "./exported_models"
```
## Output Files

The export process generates the following files in the output directory:

| File | Description |
|------|-------------|
| `model.onnx` | The exported ONNX model with fused attention operations |
| `config.json` | Model configuration (architecture, hidden size, etc.) |
| `tokenizer.json` | Tokenizer vocabulary and configuration |
| `tokenizer_config.json` | Tokenizer settings |
| `special_tokens_map.json` | Special token mappings |
| `processed_chat_template.json` | Processed chat template for inference |
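The non-ONNX artifacts follow the standard HuggingFace layout, so they can be loaded back with the usual tooling. A minimal sketch, assuming the default output directory:

```python
import json

from transformers import AutoTokenizer

# The tokenizer files are HF-standard and load directly from the export directory.
tokenizer = AutoTokenizer.from_pretrained(".")
with open("config.json") as f:
    config = json.load(f)
print(config.get("hidden_size"), tokenizer.decode(tokenizer.encode("hello")))
```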
## Programmatic Usage

You can also use the ONNX export functionality programmatically:

```python
from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig

# Create AutoDeploy config with export_edgellm_onnx mode
ad_config = AutoDeployConfig(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    mode="export_edgellm_onnx",
    max_batch_size=8,
    max_seq_len=512,
    device="cpu",
)

# Configure attention backend
ad_config.attn_backend = "torch"

# Optionally customize output location
ad_config.transforms["export_to_onnx"]["output_dir"] = "./my_output"

# Run the export
LLM(**ad_config.to_llm_kwargs())
```

## Notes

- **Device Selection**: Using `cpu` for the `--device` option is recommended to reduce GPU memory footprint during export.
- **Custom Operations**: The exported ONNX model contains custom operations (e.g., `AttentionPlugin`) in the `trt` domain that require corresponding implementations in the target inference runtime.
@@ -60,6 +60,7 @@ The exported graph then undergoes a series of automated transformations, includi
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)
- [Export ONNX for EdgeLLM](./advanced/export_onnx.md)

## Roadmap
78 examples/auto_deploy/onnx_export_llm.py Normal file
@@ -0,0 +1,78 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ONNX export script for AutoDeploy models.

This script exports a HuggingFace model to ONNX format using the AutoDeploy
transform pipeline directly, without initializing the full LLM executor.
"""

import argparse

from tensorrt_llm._torch.auto_deploy.export import export_onnx
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs


def main():
    parser = argparse.ArgumentParser(
        description="Export HuggingFace model to ONNX format using AutoDeploy."
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="The HF model to use for ONNX export.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="The device to use when exporting the model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The directory to save the exported ONNX model.",
    )
    args = parser.parse_args()

    print(f"Constructing model from {args.model}")

    # To enable a dynamic batch_size, the example batch size must be > 1.
    # NOTE(yoco): Originally this was 2; however, for unknown reasons, when set
    # to 2 the batch_size collapses to a static int 2 even though we explicitly
    # mark it as a dynamic axis. Stranger still, when set to 13 the batch_size
    # stays dynamic. Some value between 2 and 13 may also work; we use 13 here
    # for debugging purposes.
    max_batch_size = 13
    max_seq_len = 4

    # Prepare the AutoDeploy config; the mode is export_edgellm_onnx
    ad_config = LlmArgs(
        model=args.model,
        mode="export_edgellm_onnx",
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        device=args.device,
    )
    ad_config.attn_backend = "torch"
    if args.output_dir is not None:
        ad_config.transforms["export_to_onnx"]["output_dir"] = args.output_dir

    # Use the InferenceOptimizer directly instead of LLM to avoid executor initialization
    export_onnx(ad_config)


if __name__ == "__main__":
    main()
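One way to check whether the batch dimension actually stayed dynamic (the subject of the `NOTE(yoco)` comment above) is to inspect the exported graph inputs. A minimal sketch; the input names printed depend on the model:

```python
import onnx

model = onnx.load("model.onnx")
for inp in model.graph.input:
    # dim_param is a non-empty string for dynamic axes; dim_value is set otherwise.
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)
```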
@@ -13,7 +13,7 @@
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm

LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601230553-10896
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601281024-10117
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601281024-10117
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601281024-10117
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601281024-10117
@@ -10,6 +10,8 @@ mpi4py
numpy<2
onnx>=1.18.0,<1.20.0
onnx_graphsurgeon>=0.5.2
onnxscript==0.5.4
graphviz
openai
polygraphy
psutil
126 tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml Normal file
@@ -0,0 +1,126 @@
# This is the set of transforms running in "graph" mode. In this mode, we capture the full graph
# of the model and optimize it for inference.
transforms:
  ############################################################################################
  # BUILD MODEL, EXPORT TO GRAPH MODULE, AND CLEAN UP
  ############################################################################################
  build_model:
    stage: factory
    run_per_gm: false
    device: meta
    requires_clean_graph: false
  export_to_gm:
    stage: export
    clone_state_dict: false
    strict: false
    run_per_gm: false
    requires_clean_graph: false
  cleanup_noop_slice:
    stage: post_export
  cleanup_noop_add:
    stage: post_export
  cleanup_input_constraints:
    stage: post_export
  ############################################################################################
  # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
  ############################################################################################
  match_moe_pattern:
    stage: pattern_matcher
  match_dense_moe_pattern:
    stage: pattern_matcher
  match_bmm_moe_pattern:
    stage: pattern_matcher
  match_repeat_kv:
    stage: pattern_matcher
    run_shape_prop: true
  match_eager_attention:
    stage: pattern_matcher
    requires_shape_prop: true
  match_sdpa_to_torch_attention:
    stage: pattern_matcher
  match_grouped_attention:
    stage: pattern_matcher
  match_attention_layout:
    stage: pattern_matcher
    attn_layout: bsnd
  match_rope_pattern:
    stage: pattern_matcher
  match_rope_layout:
    stage: pattern_matcher
    expected_layout: bsnd
  ############################################################################################
  # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
  ############################################################################################
  eliminate_redundant_transposes:
    stage: pattern_matcher
  quantize_int4_linear_from_config:
    stage: pattern_matcher
  quantize_fp8_linear_from_config:
    stage: pattern_matcher
  quantize_nvfp4_linear_from_config:
    stage: pattern_matcher
  quantize_fp8_bmm_from_config:
    stage: pattern_matcher
  quantize_fp8_from_graph:
    stage: pattern_matcher
  quantize_nvfp4_from_graph:
    stage: pattern_matcher
  quantize_fp8_moe:
    stage: pattern_matcher
  quantize_nvfp4_moe:
    stage: pattern_matcher
  quantize_mxfp4_moe:
    stage: pattern_matcher
  ############################################################################################
  # MOVE MODEL AND LOAD WEIGHTS
  ############################################################################################
  load_weights:
    stage: weight_load
    run_per_gm: false
    checkpoint_device: cpu
  move_inputs_to_device:
    stage: weight_load
    checkpoint_device: cpu
    run_per_gm: false
  ############################################################################################
  # RUN POST-LOAD FUSION AND OPTIMIZATIONS
  ############################################################################################
  fuse_gemms:
    stage: post_load_fusion
    enabled: false # https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this causes OOMs on GPU
    # However, when exporting ONNX for EdgeLLM, we compile the model on CPU,
    # so we could enable this transform.
    # TODO(yoco): However, it currently makes non-strict ONNX export fail.
  fuse_moe:
    stage: post_load_fusion
    enabled: true
    backend: trtllm
  fuse_fp8_moe:
    stage: post_load_fusion
    enabled: true
    backend: trtllm
  fuse_nvfp4_moe:
    stage: post_load_fusion
    enabled: false
  ############################################################################################
  # VISUALIZE GRAPH
  ############################################################################################
  visualize_namespace:
    stage: visualize
    enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
  ############################################################################################
  # FUSE Rope Attention & export to ONNX
  ############################################################################################
  fuse_rope_attention:
    stage: export_onnx
  short_reshape_attention_output:
    stage: export_onnx
  gather_last_token_ids:
    stage: export_onnx
  adapt_to_edgellm:
    stage: export_onnx
  export_to_onnx:
    stage: export_onnx
    output_dir: "."
    output_name: "model.onnx"
194 tensorrt_llm/_torch/auto_deploy/custom_ops/onnx_attention.py Normal file
@@ -0,0 +1,194 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom operations for ONNX export of attention mechanisms.

This module provides placeholder custom operations for exporting attention-related
operations to ONNX format. These operations serve as intermediate representations
during the graph transformation pipeline and are intended to be replaced by actual
backend implementations during deployment.
"""

from typing import Tuple

import torch


@torch.library.custom_op("auto_deploy::torch_onnx_attention_plugin", mutates_args=())
def attention_plugin(
    # Inputs
    qkv: torch.Tensor,
    past_key_values: torch.Tensor,
    context_lengths: torch.Tensor,
    rope_rotary_cos_sin: torch.Tensor,
    kvcache_start_index: torch.Tensor,
    # Attributes
    enable_tree_attention: int,
    head_size: int,
    num_kv_heads: int,
    num_q_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused attention operation with integrated RoPE (Rotary Position Embedding).

    This custom operation combines rotary position embedding and scaled
    dot-product attention into a single fused operation. It also handles
    KV-cache management for efficient autoregressive generation.

    Note:
        This is a placeholder implementation for ONNX export. The actual computation
        is performed by the backend runtime (e.g., TensorRT, EdgeLLM).

    Args:
        qkv: Concatenated query, key, value tensor of shape
            [batch_size, seq_len, (num_q_heads + 2 * num_kv_heads) * head_size].
        past_key_values: KV-cache tensor of shape
            [batch_size, 2, num_kv_heads, past_seq_len, head_size].
        context_lengths: Sequence lengths for each batch element, shape [batch_size].
        rope_rotary_cos_sin: Precomputed RoPE cosine and sine values of shape
            [batch_size, max_seq_len, head_size].
        kvcache_start_index: Starting index in KV-cache for each batch element,
            shape [batch_size].
        enable_tree_attention: Flag to enable tree attention mode (0 or 1).
        head_size: Dimension of each attention head.
        num_kv_heads: Number of key-value heads (for grouped-query attention).
        num_q_heads: Number of query heads.

    Returns:
        A tuple containing:
            - attention_output: Attention output tensor of shape
              [batch_size, seq_len, num_q_heads, head_size].
            - present_key_values: Updated KV-cache tensor of shape
              [batch_size, 2, num_kv_heads, present_seq_len, head_size].
    """
    return qkv.new_empty(0), past_key_values.new_empty(0)


@attention_plugin.register_fake
def attention_plugin_fake(
    qkv: torch.Tensor,
    past_key_values: torch.Tensor,
    context_lengths: torch.Tensor,
    rope_rotary_cos_sin: torch.Tensor,
    kvcache_start_index: torch.Tensor,
    enable_tree_attention: int,
    head_size: int,
    num_kv_heads: int,
    num_q_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation of attention_plugin for torch.compile shape inference.

    This function computes the output shapes without performing actual computation,
    enabling torch.compile to trace through the custom operation.

    Args:
        qkv: Concatenated QKV tensor.
        past_key_values: Previous KV-cache tensor.
        context_lengths: Sequence lengths per batch.
        rope_rotary_cos_sin: RoPE embedding values.
        kvcache_start_index: KV-cache start indices.
        enable_tree_attention: Tree attention flag.
        head_size: Attention head dimension.
        num_kv_heads: Number of KV heads.
        num_q_heads: Number of query heads.

    Returns:
        Tuple of empty tensors with correct shapes for the attention output and
        present KV-cache.
    """
    batch_size = qkv.size(0)
    seq_len = qkv.size(1)
    past_len = past_key_values.size(3)
    present_kv_len = seq_len + past_len
    attn_shape = (batch_size, seq_len, num_q_heads, head_size)
    present_kv_shape = (batch_size, 2, num_kv_heads, present_kv_len, head_size)
    return torch.empty(attn_shape, device=qkv.device, dtype=qkv.dtype), torch.empty(
        present_kv_shape, device=past_key_values.device, dtype=past_key_values.dtype
    )


def _fake_gather_nd(data: torch.Tensor, indices: torch.Tensor, batch_dims: int) -> torch.Tensor:
    """Compute the output shape for the GatherND operation without actually gathering.

    This helper function creates an empty tensor with the correct output shape
    for the GatherND operation; it is used for both the actual op and its fake
    implementation.

    Args:
        data: Source tensor of shape [batch_size, seq_len, embedding_dim].
        indices: Index tensor of shape [batch_size, num_selected] or
            [batch_size, num_selected, index_depth].
        batch_dims: Number of leading batch dimensions (must be 1).

    Returns:
        Empty tensor with shape [batch_size, num_selected, embedding_dim].

    Raises:
        AssertionError: If batch_dims != 1, data is not 3D, or indices is not 2D/3D.
    """
    assert batch_dims == 1, "Currently only supports batch_dims = 1"
    assert data.ndim == 3, "Currently only supports a 3D data tensor"
    assert indices.ndim == 2 or indices.ndim == 3, "Currently only supports a 2D or 3D indices tensor"

    dim_batch_size = indices.size(0)
    dim_selected_token = indices.size(1)
    dim_emb = data.size(-1)
    result_shape = [dim_batch_size, dim_selected_token, dim_emb]
    return torch.empty(result_shape, device=data.device, dtype=data.dtype)


@torch.library.custom_op("auto_deploy::torch_onnx_gather_nd", mutates_args=())
def gather_nd(
    data: torch.Tensor,
    indices: torch.Tensor,
    batch_dims: int,
) -> torch.Tensor:
    """N-dimensional gather operation following ONNX gather_nd semantics.

    Gathers slices from the data tensor based on indices, supporting batched
    operations. This operation is commonly used for selecting specific tokens
    from a sequence based on their positions.

    Note:
        This is a placeholder implementation for ONNX export. The actual
        computation is performed by the backend runtime.

    Args:
        data: Source tensor to gather from, shape [batch_size, seq_len, embedding_dim].
        indices: Index tensor specifying which elements to gather,
            shape [batch_size, num_selected] or [batch_size, num_selected, index_depth].
        batch_dims: Number of leading dimensions to treat as batch dimensions.
            Currently only batch_dims=1 is supported.

    Returns:
        Gathered tensor of shape [batch_size, num_selected, embedding_dim].
    """
    return _fake_gather_nd(data, indices, batch_dims)


@gather_nd.register_fake
def gather_nd_fake(
    data: torch.Tensor,
    indices: torch.Tensor,
    batch_dims: int,
) -> torch.Tensor:
    """Fake implementation of gather_nd for torch.compile shape inference.

    Args:
        data: Source tensor to gather from.
        indices: Index tensor for gathering.
        batch_dims: Number of batch dimensions.

    Returns:
        Empty tensor with the correct output shape.
    """
    return _fake_gather_nd(data, indices, batch_dims)
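As a quick sanity check of the placeholder semantics, the registered ops are callable through the `torch.ops.auto_deploy` namespace; only the output shapes are meaningful. A minimal sketch with illustrative sizes:

```python
import torch

# batch_size=2, seq_len=5, embedding_dim=8; select 3 tokens per batch element.
data = torch.randn(2, 5, 8)
indices = torch.zeros(2, 3, dtype=torch.int64)

# The placeholder returns an uninitialized tensor with the gathered shape:
# [batch_size, num_selected, embedding_dim] -> [2, 3, 8].
out = torch.ops.auto_deploy.torch_onnx_gather_nd(data, indices, 1)
print(out.shape)  # torch.Size([2, 3, 8])
```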
@@ -3,7 +3,7 @@
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.export as te
@@ -15,6 +15,9 @@ from ..utils.logger import ad_logger
from ..utils.node_utils import is_op
from .interface import apply_export_patches

if TYPE_CHECKING:
    from ..llm_args import LlmArgs

try:
    from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
except ImportError:
@@ -342,3 +345,57 @@ def torch_export_to_gm(
    ad_logger.debug("exported graph: " + str(egm))

    return egm


def export_onnx(ad_config: "LlmArgs") -> nn.Module:
    """Export model to ONNX using InferenceOptimizer directly.

    This is a lightweight export path that avoids initializing the full LLM executor,
    which requires KVCacheManager and other runtime components not needed for ONNX export.

    Args:
        ad_config: The AutoDeploy configuration for the model. Should use a mode like
            "export_edgellm_onnx" that includes the export_to_onnx transform.

    Returns:
        The transformed model after running through the inference optimizer pipeline.

    Example:
        >>> from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
        >>> from tensorrt_llm._torch.auto_deploy.export import export_onnx
        >>>
        >>> ad_config = LlmArgs(
        ...     model="meta-llama/Llama-2-7b-hf",
        ...     mode="export_edgellm_onnx",
        ...     max_batch_size=13,
        ...     max_seq_len=4,
        ...     device="cpu",
        ... )
        >>> ad_config.transforms["export_to_onnx"]["output_dir"] = "/tmp/onnx_output"
        >>> model = export_onnx(ad_config)
    """
    # Import here to avoid circular imports
    from ..shim.interface import CachedSequenceInterface
    from ..transform.optimizer import InferenceOptimizer

    # 1. Create factory from config
    factory = ad_config.create_factory()

    # 2. Create CachedSequenceInterface (lightweight, no KVCacheManager initialization)
    cache_seq_interface = CachedSequenceInterface(
        max_seq_len=ad_config.max_seq_len,
        max_batch_size=ad_config.max_batch_size,
        device=ad_config.device,
        kv_cache_config=ad_config.kv_cache_config,
        max_num_tokens=ad_config.max_num_tokens,
        vocab_size_padded=factory.vocab_size_padded,
    )

    # 3. Create InferenceOptimizer with transform config
    inference_optimizer = InferenceOptimizer(
        factory=factory,
        config=ad_config.transforms,
    )

    # 4. Run the transform pipeline (includes export_to_onnx transform)
    return inference_optimizer(cache_seq_interface)
@@ -206,7 +206,7 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
    )

    ### INFERENCE OPTIMIZER CONFIG #################################################################
    mode: Literal["graph", "transformers"] = Field(
    mode: Literal["graph", "transformers", "export_edgellm_onnx"] = Field(
        default="graph",
        description="The mode to use for the inference optimizer. Currently, we "
        "support only the 'graph' and 'transformers' modes, i.e., full-graph capture + optimization"
@@ -335,5 +335,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
    mapping = {
        "graph": str(config_path / "default.yaml"),
        "transformers": str(config_path / "transformers.yaml"),
        "export_edgellm_onnx": str(config_path / "export_edgellm_onnx.yaml"),
    }
    return mapping.get(mode)
@@ -0,0 +1,691 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PyTorch GraphModule Visualization Tool

This module provides functionality to convert a PyTorch GraphModule to Graphviz diagrams.
Supports different node styles and detailed graph annotations.

Key Features:
- Convert FX GraphModule to Graphviz diagrams
- Display tensor shape information on edges
- Adjust edge width based on tensor element count
- Intelligent port assignment for multi-input/output handling
- Color coding based on tensor identity

Usage Example:
    import torch
    import torch.fx as fx
    from graph_module_visualizer import to_dot

    # Trace model
    model = YourModel()
    traced = fx.symbolic_trace(model)

    # Generate visualization
    dot = to_dot(traced, format="svg", include_shapes=True)

Requirements: pip install graphviz
"""

import math
import re
from typing import Any, Dict, Optional

import graphviz
import torch
import torch.fx as fx
from torch.fx import GraphModule

from ..utils.logger import ad_logger


def _calculate_edge_width(val) -> float:
    """
    Calculate edge width based on tensor element count.
    Formula: log10(num_elements) + 2, clamped to [2.0, 10.0].

    Args:
        val: FakeTensor, Tensor, or any object with a .shape attribute
    """
    min_width = 2.0
    max_width = 10.0
    if not hasattr(val, "shape"):
        return min_width  # Default width

    try:
        # Calculate total number of elements
        num_elements = 1
        for dim in val.shape:
            if isinstance(dim, int):
                num_elements *= dim
            else:
                num_elements *= 1

        if num_elements <= 0:
            return min_width

        # Calculate width: log10(element count) + 2
        width = math.log10(num_elements) + 2

        # Constrain width range (2.0 to 10.0)
        width = max(min_width, min(max_width, width))

        return width

    except (ValueError, TypeError):
        return min_width  # Use default width on error
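A quick check of the width formula on powers of ten (log10(n) + 2, clamped to [2, 10]):

```python
import torch

# 10 elements -> log10(10) + 2 = 3.0; 1e6 elements -> 8.0.
assert _calculate_edge_width(torch.empty(10)) == 3.0
assert _calculate_edge_width(torch.empty(100, 100, 100)) == 8.0
```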


def _get_edge_color(source_node_name: str, output_index: int) -> str:
    """
    Assign a color based on the source node and actual output index.
    Ensures all edges of the same tensor use the same color.
    """
    colors = [
        "#FF6B9D80",  # Pink
        "#4ECDC480",  # Mint green
        "#45B7D180",  # Sky blue
        "#96CEB480",  # Light green
        "#DDA0DD80",  # Light purple
        "#98D8C880",  # Teal
        "#BB8FCE80",  # Violet
        "#85C1E980",  # Light blue
        "#F8C47180",  # Peach
        "#82E0AA80",  # Mint
        "#F7DC6F80",  # Lemon yellow
        "#AED6F180",  # Light sky blue
    ]

    # Use a combination of source_node + output_index.
    # Ensures multiple edges of the same tensor have the same color.
    tensor_id = f"{source_node_name}_{output_index}"
    color_index = hash(tensor_id) % len(colors)
    return colors[color_index]


def _get_port_of_five(input_idx, total_inputs):
    """
    Distribute `total_inputs` items as evenly as possible across K = 5 ports:
    the first ``b = total_inputs % 5`` ports receive ``a + 1`` items each
    (where ``a = total_inputs // 5``), and the remaining ports receive ``a``.

    Example with total_inputs = 37 (a = 7, b = 2, x = 2 * 8 = 16):

        port 0: xxxxxxxx (8 items)
        port 1: xxxxxxxx (8 items)
        port 2: xxxxxxx  (7 items)
        port 3: xxxxxxx  (7 items)
        port 4: xxxxxxx  (7 items)
    """
    K = 5
    a, b = total_inputs // K, total_inputs % K
    x = b * (a + 1)
    if input_idx < x:
        return input_idx // (a + 1)
    else:
        return (input_idx - x) // a + b
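A quick check of the distribution logic:

```python
# With total_inputs = 7: a = 1, b = 2, x = 4; indices 0..6 map to ports
# [0, 0, 1, 1, 2, 3, 4] -- the first two ports receive two items each.
assert [_get_port_of_five(i, 7) for i in range(7)] == [0, 0, 1, 1, 2, 3, 4]
```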


def _get_input_port(input_index: int, total_inputs: int) -> str:
    """
    Get the input port based on the input index and total input count.
    """
    if total_inputs <= 1:
        return "n"  # Single input uses default port
    elif total_inputs == 2:
        return "nw" if input_index == 0 else "ne"  # Northwest, northeast
    elif total_inputs == 3:
        ports = ["nw", "n", "ne"]  # Northwest, north, northeast
        return ports[input_index]
    elif total_inputs == 4:
        ports = ["nw", "n", "ne", "e"]
        return ports[input_index]
    else:
        # 5+ inputs: west, northwest, north, northeast, east in order
        ports = ["w", "nw", "n", "ne", "e"]
        # Distribute the inputs evenly across the five ports
        return ports[_get_port_of_five(input_index, total_inputs)]


def _get_output_port(output_index: int, total_outputs: int) -> str:
    """
    Get the output port based on the output index and total output count
    (symmetric to input, but on the bottom).
    """
    if total_outputs <= 1:
        return "s"  # Single output uses default port
    elif total_outputs == 2:
        return "sw" if output_index == 0 else "se"  # Southwest, southeast
    elif total_outputs == 3:
        ports = ["sw", "s", "se"]  # Southwest, south, southeast
        return ports[output_index]
    else:
        # 4+ outputs: west, southwest, south, southeast, east in order
        ports = ["w", "sw", "s", "se", "e"]
        # Distribute the outputs evenly across the five ports
        return ports[_get_port_of_five(output_index, total_outputs)]


def to_dot(
    graph_module: GraphModule,
    name: str,
    save_path: str,
    format: str = "svg",
    include_shapes: bool = True,
) -> Optional["graphviz.Digraph"]:
    """
    Convert a PyTorch GraphModule to a Graphviz diagram.

    Args:
        graph_module: GraphModule to visualize
        name: Name of the diagram
        save_path: Save path; if None, uses name
        format: Output format ('png', 'pdf', 'svg', 'dot', etc.)
        include_shapes: Whether to include tensor shape information

    Returns:
        graphviz.Digraph object
    """
    # Create Graphviz diagram
    dot = graphviz.Digraph(
        name=name, comment=f"PyTorch GraphModule: {graph_module.__class__.__name__}", format=format
    )

    # Set graph attributes
    dot.attr(rankdir="TB")  # Top to bottom
    dot.attr("node", shape="box", style="rounded,filled", height="0.2")
    # Remove default edge color; let each edge use its own color

    # Node style configuration
    node_styles = {
        "placeholder": {"fillcolor": "lightgreen", "shape": "box"},
        "get_attr": {"fillcolor": "lightcyan", "shape": "box"},
        "call_function": {"fillcolor": "lightblue", "shape": "box"},
        "call_method": {"fillcolor": "lightyellow", "shape": "box"},
        "call_module": {"fillcolor": "lightpink", "shape": "box"},
        "output": {"fillcolor": "lightcoral", "shape": "box"},
    }

    # Analyze graph structure
    graph = graph_module.graph
    nodes = list(graph.nodes)
    node_labels = {}

    # Process each node
    for node in nodes:
        # Get basic node information
        node_name = node.name
        op_type = node.op

        # Create node label
        label = _get_node_label(graph_module, node)
        node_labels[node_name] = label

        # Set node style
        style = node_styles.get(op_type, {"fillcolor": "white", "shape": "box"})

        # Add node to diagram
        node_attrs = {
            "label": label,
            "fillcolor": style["fillcolor"],
            "shape": style["shape"],
            "tooltip": node_name,  # Use node name as tooltip
        }
        # If the node has no value, set the fillcolor to red
        if "val" not in node.meta:
            node_attrs["fillcolor"] = "red"
        elif isinstance(node.meta["val"], torch.Tensor):
            node_attrs["label"] += "\n" + str(node.meta["val"].device)
        elif isinstance(node.meta["val"], (list, tuple)):
            for val in node.meta["val"]:
                if isinstance(val, torch.Tensor):
                    node_attrs["label"] += "\n" + str(val.device)
                else:
                    node_attrs["label"] += "\n" + str(val)

        dot.node(node_name, **node_attrs)

    # First collect all edge information
    edges = []  # Format: (source_node, target_node, val, source_output_index, target_input_index)
    node_inputs = {}  # Input list for each node: [(source_node_name, output_index)]
    node_outputs = {}  # Output list for each node: [(target_node_name, input_index)]

    # Initialize
    for node in nodes:
        node_inputs[node.name] = []
        node_outputs[node.name] = []

    # Collect edge information
    def _add_edge_from_node_with_index(input_idx, source_node: fx.Node, target_node: fx.Node):
        """Add an edge from source_node to target_node, automatically determining the correct output index"""
        # Extract val (FakeTensor or Tensor) from node.meta["val"]
        # This is more reliable than tensor_meta since FakeTensorProp always populates val
        val = None
        if include_shapes and hasattr(source_node, "meta") and "val" in source_node.meta:
            val = source_node.meta["val"]

        # Calculate indices
        source_output_index = _determine_output_index(source_node, target_node)

        # Add edge and update indices (store tuples containing node names and corresponding indices)
        edges.append((source_node.name, target_node.name, val, source_output_index, input_idx))
        node_inputs[target_node.name].append((source_node.name, source_output_index))
        node_outputs[source_node.name].append((target_node.name, input_idx))

    def _determine_output_index(source_node: fx.Node, target_node: fx.Node) -> int:
        """Determine the output index for the edge from source_node to target_node"""
        # Check if target_node is a getitem operation
        if (
            target_node.op == "call_function"
            and hasattr(target_node.target, "__name__")
            and target_node.target.__name__ == "getitem"
            and len(target_node.args) >= 2
        ):
            # target_node.args[0] is the source node, target_node.args[1] is the index
            if target_node.args[0] == source_node and isinstance(target_node.args[1], int):
                return target_node.args[1]

        # By default, FX nodes have only one output
        return 0

    def _fix_problematic_negative_one(text):
        """Fix the problematic 9223372036854775807 number back to -1"""
        return text.replace("9223372036854775807", "-1")

    def _format_constant_label(value):
        """Format a constant value for display as a node label"""
        # Special case: direct problematic integer
        if isinstance(value, int) and value == 9223372036854775807:  # 2^63 - 1
            return "-1"

        # Handle different value types
        if isinstance(value, (int, float, bool, str)):
            result = str(value)
        elif hasattr(value, "shape"):
            # Tensor-like object with shape
            if hasattr(value, "numel") and value.numel() > 6:
                return f"shape={tuple(value.shape)}"
            else:
                result = str(value)
        elif isinstance(value, (list, tuple)):
            # Collection of elements
            if len(value) > 6:
                return f"length={len(value)}"
            else:
                result = str(value)
        else:
            # Other types
            result = str(value)[:50] + ("..." if len(str(value)) > 50 else "")

        # Apply the fix to any string representation
        return _fix_problematic_negative_one(result)

    # Store constants to be created later
    constants_to_create = []

    def _process_arg(input_idx, arg, target_node: fx.Node):
        """Process a single argument and create appropriate edges in the graph.

        Handles three types of arguments:
        - fx.Node: Creates an edge from the source node to the target node.
        - Container (list/tuple): Iterates through and creates edges for any fx.Node elements.
        - Constant: Creates a constant node and adds it to the graph with appropriate edges.

        Args:
            input_idx: The index of this argument in the target node's input list.
            arg: The argument to process. Can be an fx.Node, a container (list/tuple),
                or a constant value.
            target_node: The fx.Node that receives this argument as input.

        Note:
            This function modifies external state including `constants_to_create`,
            `node_inputs`, `node_outputs`, and `edges`.
        """
        if isinstance(arg, fx.Node):
            _add_edge_from_node_with_index(input_idx, arg, target_node)
        elif isinstance(arg, (list, tuple)):
            for sub_arg in arg:
                if isinstance(sub_arg, fx.Node):
                    _add_edge_from_node_with_index(input_idx, sub_arg, target_node)
        else:
            # This is a constant value - store for later creation
            const_node_name = f"const_{target_node.name}_{input_idx}"
            constants_to_create.append((const_node_name, arg, target_node.name, input_idx))

            # Add to node tracking immediately
            node_inputs[target_node.name].append(const_node_name)
            if const_node_name not in node_outputs:
                node_outputs[const_node_name] = []
            node_outputs[const_node_name].append((target_node.name, input_idx))

            # Create edge from constant to target
            edges.append((const_node_name, target_node.name, None, 0, input_idx))

    def _determine_node_output_count(node_name: str) -> int:
        """Determine the actual output count of a node (applies to all node types)"""
        # Check if any getitem operations use this node
        max_index = -1
        for n in nodes:
            # Only check the getitem operation pattern; don't restrict the source node op type
            if (
                n.op == "call_function"
                and hasattr(n.target, "__name__")
                and n.target.__name__ == "getitem"
                and len(n.args) >= 2
                and hasattr(n.args[0], "name")
                and n.args[0].name == node_name
                and isinstance(n.args[1], int)
            ):
                max_index = max(max_index, n.args[1])

        # If getitem operations were found, the output count is max_index + 1
        if max_index >= 0:
            return max_index + 1

        # By default, nodes have only one output
        return 1

    # Traverse all nodes to collect edges
    for node in nodes:
        if not node.args:
            continue

        for input_idx, arg in enumerate(node.args):
            _process_arg(input_idx, arg, node)

    # Print the 10 nodes with the most inputs
    # ad_logger.debug("Nodes with most inputs:")
    # node_inputs_sorted = sorted(node_inputs.items(), key=lambda x: len(x[1]), reverse=True)
    # for node_name, input_list in node_inputs_sorted[:10]:
    #     ad_logger.debug(f"  {node_name}: {len(input_list)}")

    # Find the 10 nodes with the most outputs
    node_outputs_sorted = sorted(node_outputs.items(), key=lambda x: len(x[1]), reverse=True)
    # ad_logger.debug("Nodes with most outputs:")
    large_fanout_nodes: Dict[str, int] = {}
    for node_name, output_list in node_outputs_sorted[:10]:
        if len(output_list) > 10:
            large_fanout_nodes[node_name] = 0
            # ad_logger.debug(f"  {node_name}: {len(output_list)}")

    # Overwrite the style of large fan-out nodes
    for node_name in large_fanout_nodes:
        # Add node to diagram
        dot.node(
            node_name,
            label=node_name,
            fillcolor="#ffdddd80",
            color="#88888880",
            shape="box",
            style="filled,dashed,rounded",
            tooltip=node_name,  # Use node name as tooltip
        )

    node_inputs_sorted = sorted(node_inputs.items(), key=lambda x: len(x[1]), reverse=True)
    large_fanin_nodes: Dict[str, int] = {}
    for node_name, input_list in node_inputs_sorted[:10]:
        if len(input_list) > 12:
            large_fanin_nodes[node_name] = 0

    for node_name in large_fanin_nodes:
        # Add node to diagram
        dot.node(
            node_name,
            label=node_name,
            fillcolor="#ddffdd80",
            color="#88888880",
            shape="box",
            style="filled,dashed,rounded",
            tooltip=node_name,  # Use node name as tooltip
        )

    # Create constant nodes
    for const_node_name, const_value, target_name, input_idx in constants_to_create:
        const_label = _format_constant_label(const_value)

        # Add constant node to the dot graph
        const_attrs = {
            "label": const_label,
            "shape": "box",
            "style": "rounded,filled",
            "fillcolor": "#ffffcc",  # Light yellow background
            "color": "#cccccc",  # Light gray border
            "width": "0.2",
            "fontsize": "10",
        }
        dot.node(const_node_name, **const_attrs)

    # Add edges with ports and colors
    for source_name, target_name, val, source_output_index, target_input_index in edges:
        edge_attrs = {}

        # Calculate ports (for graphical display positioning)
        input_list = node_inputs[target_name]

        # Use the actual output count, not the usage count
        source_output_count = _determine_node_output_count(source_name)

        input_port = _get_input_port(target_input_index, len(input_list))
        output_port = _get_output_port(source_output_index, source_output_count)

        # Build node names with ports
        source_name_port = f"{source_name}:{output_port}" if output_port else source_name
        target_name_port = f"{target_name}:{input_port}" if input_port else target_name

        # Set edge color (based on the actual output_index)
        edge_color = _get_edge_color(source_name, source_output_index)
        edge_attrs["color"] = edge_color

        # Add tensor shape and width information
        # val can be a FakeTensor, Tensor, or another type with a .shape attribute
        if val is not None and include_shapes and hasattr(val, "shape"):
            shape_str = str(tuple(val.shape))
            # Add dtype if available
            dtype_str = ""
            if hasattr(val, "dtype"):
                dtype_str = str(val.dtype).replace("torch.", "")
            edge_attrs["xlabel"] = f"{shape_str}\n{dtype_str}" if dtype_str else shape_str
            edge_attrs["fontsize"] = "10"
            edge_attrs["fontcolor"] = "blue"

            # Calculate edge width based on element count
            width = _calculate_edge_width(val)
            edge_attrs["penwidth"] = str(width)

        # Large fan-out nodes can stall the Graphviz layout algorithm, so we
        # duplicate such nodes so that each edge gets its own source node,
        # which keeps the layout tractable. The same applies to large fan-in nodes.
        if source_name in large_fanout_nodes:
            node_attrs = {
                "fillcolor": "#ddffdd80",
                "color": "#88888880",
                "shape": "box",
                "style": "filled,dashed,rounded",
            }
            large_fanout_nodes[source_name] += 1
            source_name_port = source_name + f"___{large_fanout_nodes[source_name]}"
            dot.node(
                name=source_name_port,
                label=source_name,
                **node_attrs,
            )

        if target_name in large_fanin_nodes:
            node_attrs = {
                "fillcolor": "#ddffdd80",
                "color": "#88888880",
                "shape": "box",
                "style": "filled,dashed,rounded",
            }
            large_fanin_nodes[target_name] += 1
            target_name_port = target_name + f"___{large_fanin_nodes[target_name]}"
            dot.node(name=target_name_port, label=target_name, **node_attrs)

        dot.edge(source_name_port, target_name_port, **edge_attrs)

    # Save diagram
    try:
        dot.render(save_path, cleanup=True)
        ad_logger.info(f"Diagram saved: {save_path}.{format}")
        with open(save_path + ".txt", "w") as f:
            f.write(str(graph_module.graph))
    except Exception as e:
        ad_logger.error(f"Failed to save diagram: {e}")

    return dot
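A minimal usage sketch for `to_dot` (assuming the `graphviz` system package and Python bindings are installed; the model is illustrative):

```python
import torch
import torch.fx as fx


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(TinyModel())
# Writes /tmp/tiny_graph.svg plus a /tmp/tiny_graph.txt dump of the FX graph.
dot = to_dot(gm, name="tiny", save_path="/tmp/tiny_graph", format="svg")
```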


def analyze_graph_structure(graph_module: GraphModule) -> Dict[str, Any]:
    """
    Analyze structural statistics of a GraphModule.

    Args:
        graph_module: GraphModule to analyze

    Returns:
        Dictionary containing structural statistics
    """
    graph = graph_module.graph
    nodes = list(graph.nodes)

    # Count node types
    op_counts = {}
    for node in nodes:
        op_type = node.op
        op_counts[op_type] = op_counts.get(op_type, 0) + 1

    # Analyze connections
    total_connections = 0
    for node in nodes:
        if node.args:
            for arg in node.args:
                if isinstance(arg, fx.Node):
                    total_connections += 1
                elif isinstance(arg, (list, tuple)):
                    for sub_arg in arg:
                        if isinstance(sub_arg, fx.Node):
                            total_connections += 1

    # Calculate graph complexity
    complexity_score = len(nodes) + total_connections

    return {
        "total_nodes": len(nodes),
        "node_types": op_counts,
        "total_connections": total_connections,
        "complexity_score": complexity_score,
        "graph_depth": _calculate_graph_depth(nodes),
    }
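Continuing the sketch above, the statistics helper applies to the same traced module:

```python
stats = analyze_graph_structure(gm)  # gm from the to_dot sketch above
print(stats["total_nodes"], stats["graph_depth"], stats["node_types"])
```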


def _get_node_label(graph_module: GraphModule, node: fx.Node) -> str:
    """Get node label"""
    if node.op == "call_function":
        func_name = _get_function_name(node.target)
        tokens = func_name.split(".")
        assert len(tokens) <= 2, f"Function name {func_name} has more than 2 tokens"
        label = tokens[0] if tokens[0] != "to" else func_name
    elif node.op == "call_method":
        label = str(node.target)
    elif node.op == "call_module":
        label = _get_module_name(graph_module, node.target)
    elif node.op == "get_attr":
        attr_name = str(node.target).split(".")[-1] if "." in str(node.target) else str(node.target)
        label = attr_name
    elif node.op == "placeholder":
        label = "ph: " + str(node.name)
    elif node.op == "output":
        label = "out: " + str(node.name)
    else:
        label = node.op
    return label


def _get_function_name(func) -> str:
    """Get simplified function name"""
    if hasattr(func, "__name__"):
        return func.__name__

    func_str = str(func)

    # Handle torch functions
    if "torch." in func_str:
        match = re.search(r"torch\.(\w+)\.(\w+)", func_str)
        if match:
            return f"{match.group(1)}.{match.group(2)}"

        match = re.search(r"torch\.(\w+)", func_str)
        if match:
            return match.group(1)

    # Handle built-in functions
    if "built-in" in func_str:
        match = re.search(r"'(\w+)'", func_str)
        if match:
            return match.group(1)

    return str(func).split(".")[-1] if "." in str(func) else str(func)


def _get_module_name(graph_module: GraphModule, target) -> str:
    """Extract the module name; handle numeric indices in Sequential"""
    try:
        # Try to get the actual module type name
        actual_module = graph_module.get_submodule(str(target))
        module_type = actual_module.__class__.__name__

        # Extract the last part of the module name
        module_name = str(target).split(".")[-1] if "." in str(target) else str(target)

        # If it's a numeric index (like modules in Sequential), show the type name
        if module_name.isdigit():
            return module_type
        else:
            return module_name
    except Exception:
        # If unable to get the module, fall back to the original logic
        module_name = str(target).split(".")[-1] if "." in str(target) else str(target)
        return module_name


def _calculate_graph_depth(nodes) -> int:
    """Calculate the maximum depth of the graph"""
    # Build dependency relationships
    dependencies = {}
    for node in nodes:
        dependencies[node.name] = []
        if node.args:
            for arg in node.args:
                if isinstance(arg, fx.Node):
                    dependencies[node.name].append(arg.name)
                elif isinstance(arg, (list, tuple)):
                    for sub_arg in arg:
                        if isinstance(sub_arg, fx.Node):
                            dependencies[node.name].append(sub_arg.name)

    # Calculate the depth of each node
    depths = {}

    def calculate_depth(node_name):
        if node_name in depths:
            return depths[node_name]

        if not dependencies[node_name]:
            depths[node_name] = 0
            return 0

        max_dep_depth = max(calculate_depth(dep) for dep in dependencies[node_name])
        depths[node_name] = max_dep_depth + 1
        return depths[node_name]

    for node in nodes:
        calculate_depth(node.name)

    return max(depths.values()) if depths else 0
@@ -3,14 +3,16 @@
This module defines the base classes and interfaces for all transforms.
"""

import os
import time
from abc import ABC
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import total_ordering, wraps
from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, final

import torch
import torch.nn as nn
from pydantic import BaseModel, Field
from torch.fx import GraphModule, Node
@@ -27,6 +29,7 @@ from ..utils._graph import (
)
from ..utils.cuda_mem_tracker import get_mem_info
from ..utils.logger import ad_logger
from .graph_module_visualizer import to_dot

# ANSI color codes for log formatting (set to False to disable colors)
# NOTE: colors disabled by default to make logging in CI/CD pipelines easier to read
@@ -102,6 +105,7 @@ class Stages(Enum):
    POST_LOAD_FUSION = "post_load_fusion"  # post-loading fusion and perf optimizations of the graph
    CACHE_INIT = "cache_init"  # initialization of cached attention + (KV) cache initialization
    VISUALIZE = "visualize"  # visualization of the graph
    EXPORT_ONNX = "export_onnx"  # export the graph to ONNX
    COMPILE = "compile"  # graph compilation stage using low-level compilers like torch.compile

    def __lt__(self, other):
@@ -166,6 +170,11 @@ class TransformConfig(BaseModel):
        default=False,
        description="Whether this transform requires shape propagation before it is applied.",
    )
    debug_visualize_dir: Optional[str] = Field(
        default=None,
        description="Debug visualization directory. None to disable visualization, "
        "or a path string to specify the output directory.",
    )

    expect_mem_change: bool = Field(
        default=False,
@@ -257,6 +266,7 @@ def with_transform_logging(call_fn: Callable) -> Callable:
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
        idx: int,
    ) -> nn.Module:
        prefix = f"[stage={self.config.stage.value}, transform={self.get_transform_key()}]"
        original_log = ad_logger.log
@@ -268,7 +278,7 @@ def with_transform_logging(call_fn: Callable) -> Callable:

        ad_logger.log = _patched_log  # type: ignore[assignment]
        try:
            return call_fn(self, gm, cm, factory, shared_config)
            return call_fn(self, gm, cm, factory, shared_config, idx)
        finally:
            ad_logger.log = original_log  # type: ignore[assignment]

@@ -346,6 +356,7 @@ class BaseTransform(ABC):
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
        idx: int,
    ) -> nn.Module:
        """Apply the transform to the graph.

@@ -354,6 +365,7 @@ class BaseTransform(ABC):
            cm: The cached sequence interface defining the sequence interface.
            factory: The model factory used to build the model.
            shared_config: Global info shared between multiple transforms.
            idx: The index of the transform in the pipeline.

        Returns:
            nn.Module: The transformed model.
@@ -473,10 +485,33 @@ class BaseTransform(ABC):
            autodeploy_meta[self._mem_history_key] = mem_history

        self._set_autodeploy_meta(mod, autodeploy_meta)
        self._visualize_graph(mod, idx)

        # return the graph module
        return mod

    @final
    def _visualize_graph(self, mod: nn.Module, idx: int) -> None:
        """Visualize the graph if debug visualization is enabled.

        Args:
            mod: The graph module to visualize.
            idx: The index of the transform in the pipeline.

        Note:
            We may want to consider doing this for each subgraph.
            See https://github.com/NVIDIA/TensorRT-LLM/issues/10203
        """
        if not isinstance(mod, torch.fx.GraphModule):
            return
        visualize_dir = self.config.debug_visualize_dir
        if not visualize_dir:
            return
        if not os.path.exists(visualize_dir):
            os.makedirs(visualize_dir)
        name_stem = f"gm_{idx + 1:02d}_{self.get_transform_key()}"
        visualize_path = os.path.join(visualize_dir, f"{name_stem}")
        to_dot(mod, name=name_stem, save_path=visualize_path, format="svg")
        ad_logger.debug(f"[{idx + 1:02d}] Visualized {name_stem} to {visualize_path}")
@final
|
||||
def _apply_per_gm_or_whole_model(
|
||||
self,
|
||||
|
||||
@ -0,0 +1,382 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Chat template processing for the EdgeLLM API. Processes the model's chat template
to create a JSON file with chat template data for the following:
|
||||
|
||||
Roles:
|
||||
- System
|
||||
- User
|
||||
- Assistant
|
||||
|
||||
Messages:
|
||||
- Role
|
||||
- Content
|
||||
- Type
|
||||
- text
|
||||
- image
|
||||
- video
|
||||
|
||||
The JSON file is saved to the exported ONNX model directory.
|
||||
|
||||
This implementation uses the HF tokenizer's apply_chat_template method with test cases
|
||||
to extract the actual prefix/suffix patterns used by the model, rather than trying
|
||||
to parse the Jinja template directly.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from transformers import AutoConfig, AutoProcessor, AutoTokenizer
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
|
||||
|
||||
def is_vlm(model_dir: str) -> bool:
|
||||
"""Check if the model is a VLM."""
|
||||
cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
||||
cfg_dict = cfg.to_dict()
|
||||
has_vision = "vision_config" in cfg_dict
|
||||
has_phi4_vision = "image_embd_layer" in cfg_dict.get("embd_layer", {})
|
||||
if has_vision or has_phi4_vision:
|
||||
ad_logger.debug("Set use_prompt_tuning to True")
|
||||
return True
|
||||
else:
|
||||
ad_logger.debug("Set use_prompt_tuning to False")
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
role: str
|
||||
content: Union[str, List[Dict[str, str]]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessage(Message):
|
||||
role: str = "system"
|
||||
content: str = "<placeholder_system_prompt>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage(Message):
|
||||
role: str = "user"
|
||||
content: str = "<placeholder_user_text>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultimodalUserMessage(Message):
|
||||
role: str = "user"
|
||||
content: List[Dict[str, str]] = field(
|
||||
default_factory=lambda: [{"type": "text", "text": "<placeholder_user_text>"}]
|
||||
)
|
||||
|
||||
def add_text_content(self, text: str):
|
||||
self.content.append({"type": "text", "text": text})
|
||||
|
||||
def add_image_content(self, image: str):
|
||||
self.content.append({"type": "image", "image": image})
|
||||
|
||||
def add_video_content(self, video: str):
|
||||
self.content.append({"type": "video", "video": video})
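# Minimal usage sketch (illustrative only, not part of the module): `tokenizer` is
# assumed to be any HF tokenizer/processor that exposes apply_chat_template.
#
#     msg = MultimodalUserMessage()
#     msg.add_image_content("<placeholder_image_path>")
#     formatted = _format_messages(tokenizer, [SystemMessage(), msg])
#
# `formatted` then contains the template's image markup around the placeholder, which
# is what _extract_content_pattern (defined below) diffs against the text-only case.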
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantMessage(Message):
|
||||
role: str = "assistant"
|
||||
content: str = "<placeholder_assistant_text>"
|
||||
# TODO: Add tool calling
|
||||
|
||||
|
||||
# TODO: Add ToolMessage
|
||||
|
||||
|
||||
def _format_messages(
|
||||
tokenizer: Any, messages: List[Message], add_generation_prompt: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Format the messages using the tokenizer's chat template.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace loaded tokenizer
|
||||
messages: List of messages
|
||||
add_generation_prompt: Whether to add generation prompt
|
||||
|
||||
Returns:
|
||||
Formatted text
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to format messages
|
||||
"""
|
||||
try:
|
||||
# Convert dataclass messages to dictionaries using asdict
|
||||
message_dicts = [asdict(msg) for msg in messages]
|
||||
|
||||
return tokenizer.apply_chat_template(
|
||||
message_dicts, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
except Exception:
|
||||
# Try fallback: convert list content to string for tokenizers that don't support multimodal
|
||||
try:
|
||||
message_dicts = []
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
# If content is a list, extract the first text element
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
content = item.get("text", "")
|
||||
break
|
||||
message_dicts.append({"role": msg.role, "content": content})
|
||||
|
||||
return tokenizer.apply_chat_template(
|
||||
message_dicts, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
except Exception as e2:
|
||||
raise ValueError(
f"Unable to format messages using HuggingFace tokenizer's apply_chat_template method. "
f"Messages need to be in the format: role: <str>, content: <str|list of dicts>. "
f"Check INPUT_FORMAT.md for more details. "
f"Error: {e2}"
) from e2
|
||||
|
||||
|
||||
def _extract_prefix_suffix(text: str, placeholder: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Extract prefix and suffix from the differential text by finding the placeholder content.
|
||||
|
||||
Args:
|
||||
text: The text to extract the prefix and suffix from
|
||||
placeholder : The placeholder content to search for in the formatted output
|
||||
|
||||
Returns:
|
||||
Tuple of (prefix, suffix) strings
|
||||
"""
|
||||
content_start = text.find(placeholder)
|
||||
|
||||
if content_start == -1:
|
||||
return "", ""
|
||||
|
||||
prefix = text[:content_start]
|
||||
suffix = text[content_start + len(placeholder) :]
|
||||
|
||||
return prefix, suffix
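# Worked example (hedged; the strings below are a hypothetical ChatML-style rendering,
# not taken from any specific model):
#
#     text = "<|im_start|>system\n<placeholder_system_prompt><|im_end|>\n"
#     _extract_prefix_suffix(text, "<placeholder_system_prompt>")
#     # -> ("<|im_start|>system\n", "<|im_end|>\n")
#
# The actual prefix/suffix strings depend entirely on the model's chat template.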
|
||||
|
||||
|
||||
def _extract_content_pattern(
|
||||
tokenizer: Any,
|
||||
system_prompt: SystemMessage,
|
||||
content_type: str,
|
||||
placeholder: str,
|
||||
text_only_formatted: str,
|
||||
placeholder_text: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Extract the pattern for a specific content type (image/video) by comparing
|
||||
with text-only message.
|
||||
|
||||
Args:
|
||||
tokenizer: The loaded tokenizer
|
||||
system_prompt: System message to use
|
||||
content_type: Type of content ('image' or 'video')
|
||||
placeholder: Placeholder string for the content
|
||||
text_only_formatted: Formatted text-only message
|
||||
placeholder_text: The text placeholder used
|
||||
|
||||
Returns:
|
||||
Extracted pattern string or None if failed or tokenizer does not support multimodal content
|
||||
"""
|
||||
# Create user message with the content type
|
||||
user_with_content = MultimodalUserMessage()
|
||||
if content_type == "image":
|
||||
user_with_content.add_image_content(placeholder)
|
||||
elif content_type == "video":
|
||||
user_with_content.add_video_content(placeholder)
|
||||
else:
|
||||
return None
|
||||
|
||||
with_content_formatted = _format_messages(tokenizer, [system_prompt, user_with_content])
|
||||
|
||||
# Extract the differential - what was added for this content type
|
||||
if placeholder_text in text_only_formatted and placeholder_text in with_content_formatted:
|
||||
# Find position after the placeholder in both
|
||||
text_pos = text_only_formatted.find(placeholder_text) + len(placeholder_text)
|
||||
content_pos = with_content_formatted.find(placeholder_text) + len(placeholder_text)
|
||||
|
||||
# Get what comes after the placeholder in both
|
||||
text_only_suffix = text_only_formatted[text_pos:]
|
||||
with_content_suffix = with_content_formatted[content_pos:]
|
||||
|
||||
# The pattern is what was added (the difference)
|
||||
if text_only_suffix and with_content_suffix.endswith(text_only_suffix):
|
||||
pattern = with_content_suffix[: -len(text_only_suffix)]
|
||||
else:
|
||||
pattern = with_content_suffix
|
||||
|
||||
# Strip dynamic prefixes like "Image 1:" or "Video 1:"
|
||||
pattern = re.sub(rf"^{content_type.capitalize()} \d+:\s*", "", pattern)
|
||||
|
||||
return pattern if pattern else None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def process_chat_template(model_dir: str, output_dir: str) -> None:
|
||||
"""
|
||||
Process the chat template from model's tokenizer and create a JSON file
|
||||
with parsed template information.
|
||||
|
||||
This function uses the tokenizer's apply_chat_template method with various
|
||||
test cases to extract the actual prefix/suffix patterns.
|
||||
|
||||
Args:
|
||||
model_dir: Path to the model directory containing tokenizer files
|
||||
output_dir: Path to save the chat_template.json file
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
ad_logger.info(f"Processing chat template from {model_dir}")
|
||||
|
||||
tokenizer = None
|
||||
loaders = (
|
||||
[AutoProcessor, AutoTokenizer] if is_vlm(model_dir) else [AutoTokenizer, AutoProcessor]
|
||||
)
|
||||
for ldr in loaders:
|
||||
try:
|
||||
tokenizer = ldr.from_pretrained(model_dir, trust_remote_code=True)
|
||||
if getattr(tokenizer, "chat_template", None):
|
||||
ad_logger.debug(f"Successfully loaded chat template from {ldr.__name__}")
|
||||
break
|
||||
else:
|
||||
ad_logger.debug(f"{ldr.__name__} loaded but no chat template found")
|
||||
tokenizer = None
|
||||
except Exception as e:
|
||||
ad_logger.error(f"Failed to load {ldr.__name__}: {e}")
|
||||
tokenizer = None
|
||||
|
||||
if tokenizer is None:
|
||||
ad_logger.debug("Skipping chat template processing - no chat template available")
|
||||
return
|
||||
|
||||
ad_logger.debug("Extracting patterns from chat template...")
|
||||
|
||||
# Extract system role patterns (base case)
|
||||
system_prompt = SystemMessage()
|
||||
system_formatted = _format_messages(tokenizer, [system_prompt])
|
||||
system_prefix, system_suffix = _extract_prefix_suffix(system_formatted, system_prompt.content)
|
||||
|
||||
# Extract user role patterns (compare with system base)
|
||||
user_prompt = UserMessage()
|
||||
user_formatted = _format_messages(tokenizer, [system_prompt, user_prompt])
|
||||
user_prefix, user_suffix = _extract_prefix_suffix(
|
||||
user_formatted[len(system_formatted) :], user_prompt.content
|
||||
)
|
||||
|
||||
# Extract assistant role patterns (compare with user case)
|
||||
assistant_prompt = AssistantMessage()
|
||||
assistant_formatted = _format_messages(
|
||||
tokenizer, [system_prompt, user_prompt, assistant_prompt]
|
||||
)
|
||||
assistant_prefix, assistant_suffix = _extract_prefix_suffix(
|
||||
assistant_formatted[len(user_formatted) :], assistant_prompt.content
|
||||
)
|
||||
|
||||
# Extract generation prompt
|
||||
generation_formatted = _format_messages(
|
||||
tokenizer, [system_prompt, user_prompt], add_generation_prompt=True
|
||||
)
|
||||
generation_prompt = generation_formatted[len(user_formatted) :]
|
||||
|
||||
# Build content types
|
||||
content_types = {}
|
||||
|
||||
# Only extract multimodal patterns if this is a VLM model
|
||||
if is_vlm(model_dir):
|
||||
ad_logger.debug("Detected VLM model, extracting multimodal content patterns...")
|
||||
# Get base text-only formatted message for comparison
|
||||
user_text_only = MultimodalUserMessage()
|
||||
text_only_formatted = _format_messages(tokenizer, [system_prompt, user_text_only])
|
||||
placeholder_text = user_text_only.content[0]["text"]
|
||||
|
||||
# Extract image pattern
|
||||
image_pattern = _extract_content_pattern(
|
||||
tokenizer,
|
||||
system_prompt,
|
||||
"image",
|
||||
"<placeholder_image_path>",
|
||||
text_only_formatted,
|
||||
placeholder_text,
|
||||
)
|
||||
if image_pattern:
|
||||
content_types["image"] = {"format": image_pattern}
|
||||
|
||||
# Extract video pattern
|
||||
video_pattern = _extract_content_pattern(
|
||||
tokenizer,
|
||||
system_prompt,
|
||||
"video",
|
||||
"<placeholder_video_path>",
|
||||
text_only_formatted,
|
||||
placeholder_text,
|
||||
)
|
||||
if video_pattern:
|
||||
content_types["video"] = {"format": video_pattern}
|
||||
else:
|
||||
ad_logger.debug("Text-only LLM detected, skipping multimodal content pattern extraction")
|
||||
|
||||
# Extract default system prompt by testing without system message
|
||||
user_only_prompt = UserMessage()
|
||||
user_only_formatted = _format_messages(tokenizer, [user_only_prompt])
|
||||
|
||||
# Extract default system prompt
|
||||
default_system_prompt = ""
|
||||
# Check if a default system prompt was added
|
||||
# The system message should appear in user_only_formatted if there's a default
|
||||
system_start = user_only_formatted.find(system_prefix)
|
||||
if system_start != -1:
|
||||
# Extract the system content between prefix and suffix
|
||||
content_start = system_start + len(system_prefix)
|
||||
content_end = user_only_formatted.find(system_suffix, content_start)
|
||||
if content_end != -1:
|
||||
default_system_prompt = user_only_formatted[content_start:content_end]
|
||||
# If the extracted content is just our placeholder, there is no real default system prompt
|
||||
if default_system_prompt == system_prompt.content:
|
||||
default_system_prompt = ""
|
||||
|
||||
# Build the final JSON structure
|
||||
chat_template_data = {
|
||||
"model_path": model_dir,
|
||||
"roles": {
|
||||
"system": {"prefix": system_prefix, "suffix": system_suffix},
|
||||
"user": {"prefix": user_prefix, "suffix": user_suffix},
|
||||
"assistant": {"prefix": assistant_prefix, "suffix": assistant_suffix},
|
||||
},
|
||||
"content_types": content_types,
|
||||
"generation_prompt": generation_prompt,
|
||||
"default_system_prompt": default_system_prompt,
|
||||
}
|
||||
|
||||
# Save to output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = os.path.join(output_dir, "processed_chat_template.json")
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(chat_template_data, f, indent=2)
|
||||
|
||||
ad_logger.info(f"Chat template saved to {output_path}")
|
||||
@ -0,0 +1,205 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
|
||||
EDGELLM_VERSION = "0.5.0.0"
|
||||
|
||||
|
||||
def _export_native_llm_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Export LLM configuration with required fields."""
|
||||
required_fields = [
|
||||
"vocab_size",
|
||||
"max_position_embeddings",
|
||||
"hidden_size",
|
||||
"intermediate_size",
|
||||
"num_hidden_layers",
|
||||
"num_attention_heads",
|
||||
"num_key_value_heads",
|
||||
"rope_theta",
|
||||
"rope_scaling",
|
||||
]
|
||||
|
||||
llm_config = {}
|
||||
for field in required_fields:
|
||||
if field not in config_dict:
|
||||
raise KeyError(f"Required field '{field}' not found in config")
|
||||
llm_config[field] = config_dict[field]
|
||||
|
||||
# Handle LongRoPE (rope_scaling already validated in required_fields)
|
||||
rope_scaling = config_dict["rope_scaling"]
|
||||
if rope_scaling and rope_scaling.get("type", None) == "longrope":
|
||||
if "original_max_position_embeddings" not in config_dict:
|
||||
raise KeyError("Required field 'original_max_position_embeddings' not found in config")
|
||||
llm_config["original_max_position_embeddings"] = config_dict[
|
||||
"original_max_position_embeddings"
|
||||
]
|
||||
|
||||
# Handle head_dim
|
||||
if "head_dim" in config_dict:
|
||||
llm_config["head_dim"] = config_dict["head_dim"]
|
||||
else:
|
||||
ad_logger.warning(
|
||||
"Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads"
|
||||
)
|
||||
llm_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"]
|
||||
|
||||
if "partial_rotary_factor" in config_dict:
|
||||
llm_config["partial_rotary_factor"] = config_dict["partial_rotary_factor"]
|
||||
else:
|
||||
llm_config["partial_rotary_factor"] = 1.0
|
||||
|
||||
llm_config["model_type"] = "llm"
|
||||
return llm_config
|
||||
|
||||
|
||||
def _export_eagle_base_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Export EAGLE base configuration with required fields."""
|
||||
required_fields = [
|
||||
"vocab_size",
|
||||
"max_position_embeddings",
|
||||
"hidden_size",
|
||||
"intermediate_size",
|
||||
"num_hidden_layers",
|
||||
"num_attention_heads",
|
||||
"num_key_value_heads",
|
||||
"rope_theta",
|
||||
"rope_scaling",
|
||||
]
|
||||
|
||||
eagle_config = {}
|
||||
for field in required_fields:
|
||||
if field not in config_dict:
|
||||
raise KeyError(f"Required field '{field}' not found in config")
|
||||
eagle_config[field] = config_dict[field]
|
||||
|
||||
# Handle head_dim
|
||||
if "head_dim" in config_dict:
|
||||
eagle_config["head_dim"] = config_dict["head_dim"]
|
||||
else:
|
||||
ad_logger.warning(
|
||||
"Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads"
|
||||
)
|
||||
eagle_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"]
|
||||
if "partial_rotary_factor" in config_dict:
|
||||
eagle_config["partial_rotary_factor"] = config_dict["partial_rotary_factor"]
|
||||
else:
|
||||
eagle_config["partial_rotary_factor"] = 1.0
|
||||
|
||||
eagle_config["model_type"] = "eagle3_base"
|
||||
return eagle_config
|
||||
|
||||
|
||||
def _export_eagle_draft_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Export EAGLE draft configuration with required fields."""
|
||||
required_fields = [
|
||||
"hidden_size",
|
||||
"max_position_embeddings",
|
||||
"intermediate_size",
|
||||
"num_hidden_layers",
|
||||
"num_attention_heads",
|
||||
"num_key_value_heads",
|
||||
"rope_theta",
|
||||
"rope_scaling",
|
||||
]
|
||||
|
||||
draft_config = {}
|
||||
for field in required_fields:
|
||||
if field not in config_dict:
|
||||
raise KeyError(f"Required field '{field}' not found in config")
|
||||
draft_config[field] = config_dict[field]
|
||||
|
||||
# Handle head_dim
|
||||
if "head_dim" in config_dict:
|
||||
draft_config["head_dim"] = config_dict["head_dim"]
|
||||
else:
|
||||
ad_logger.warning(
|
||||
"Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads"
|
||||
)
|
||||
draft_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"]
|
||||
|
||||
# Handle draft_vocab_size based on EAGLE version
|
||||
if "draft_vocab_size" not in config_dict:
|
||||
raise KeyError("Required field 'draft_vocab_size' not found in config")
|
||||
draft_config["draft_vocab_size"] = config_dict["draft_vocab_size"]
|
||||
|
||||
# Add base model configuration fields
|
||||
# The target_hidden_size from the model config represents the base model's hidden dimension
|
||||
if "target_hidden_size" in config_dict:
|
||||
# Use target_hidden_size * 3 as the base model hidden dimension (as per llm_export.py logic)
|
||||
draft_config["base_model_hidden_size"] = config_dict["target_hidden_size"] * 3
|
||||
else:
|
||||
# Fallback: assume base model hidden size is 3x draft model (Eagle3 default)
|
||||
draft_config["base_model_hidden_size"] = config_dict["hidden_size"] * 3
|
||||
ad_logger.warning(
|
||||
f"Warning: target_hidden_size not found, using default 3x draft hidden size: "
|
||||
f"{draft_config['base_model_hidden_size']}"
|
||||
)
|
||||
|
||||
# Set model_type for draft
|
||||
draft_config["model_type"] = "eagle3_draft"
|
||||
|
||||
return draft_config
|
||||
|
||||
|
||||
def export_vision_config(config: Any) -> Dict[str, Any]:
|
||||
"""Export vision configuration without modification."""
|
||||
config_dict = config.to_dict()
|
||||
|
||||
has_vision = "vision_config" in config_dict
|
||||
has_phi4_vision = "image_embd_layer" in config_dict.get("embd_layer", {})
|
||||
if not (has_vision or has_phi4_vision):
|
||||
raise KeyError(
|
||||
"Required field 'vision_config' or 'image_embd_layer' in 'embd_layer' not found in config"
|
||||
)
|
||||
# Add EdgeLLM API version
|
||||
config_dict["edgellm_version"] = EDGELLM_VERSION
|
||||
|
||||
# Return the original config_dict as-is without any modification
|
||||
# Since MRoPE needs LLM config, ViTRunner will use the LLM config.
|
||||
return config_dict
|
||||
|
||||
|
||||
def export_llm_config(config: Any, model_type: str) -> Dict[str, Any]:
|
||||
"""Export configuration based on model type and EAGLE version."""
|
||||
config_dict = config.to_dict()
|
||||
|
||||
# Extract model name from config class
|
||||
config_class_name = config.__class__.__name__
|
||||
model_name = config_class_name.lower().replace("config", "")
|
||||
|
||||
# For multimodal models, use text_config if available
|
||||
if "text_config" in config_dict:
|
||||
ad_logger.info("Detected multimodal model, using text_config")
|
||||
config_dict = config_dict["text_config"]
|
||||
|
||||
if model_type == "llm":
|
||||
output_config = _export_native_llm_config(config_dict)
|
||||
elif model_type == "eagle3_base":
|
||||
output_config = _export_eagle_base_config(config_dict)
|
||||
elif model_type == "eagle_draft":
|
||||
output_config = _export_eagle_draft_config(config_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
# Add model name to output
|
||||
output_config["model"] = model_name
|
||||
|
||||
# Add EdgeLLM API version
|
||||
output_config["edgellm_version"] = EDGELLM_VERSION
|
||||
|
||||
return output_config
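# Hedged usage sketch (not part of the module): exporting the config for a plain
# decoder-only checkpoint, assuming its HF config exposes all of the required fields
# listed above (vocab_size, rope_theta, rope_scaling, ...).
#
#     from transformers import AutoConfig
#     cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#     llm_cfg = export_llm_config(cfg, "llm")
#     # llm_cfg now also carries "model", "model_type" and "edgellm_version",
#     # ready to be written to config.json.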
@ -0,0 +1,258 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from onnx import defs
|
||||
from onnxscript.values import Opset
|
||||
|
||||
_TRT_DOMAIN_NAME = "trt"
|
||||
_AUTO_DEPLOY_DOMAIN_NAME = "auto_deploy"
|
||||
|
||||
# public opset objects, used in ONNX translation functions
|
||||
trt_opset = Opset(_TRT_DOMAIN_NAME, 1)
|
||||
auto_deploy_opset = Opset(_AUTO_DEPLOY_DOMAIN_NAME, 1)
|
||||
|
||||
|
||||
# ONNX Custom Op Registration for RoPE
|
||||
_torch_rope_with_explicit_cos_sin_schema = defs.OpSchema(
|
||||
name="rope_with_explicit_cos_sin",
|
||||
domain=_AUTO_DEPLOY_DOMAIN_NAME,
|
||||
since_version=1,
|
||||
doc="Rope with explicit cos and sin caches.",
|
||||
inputs=[
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="q",
|
||||
description="Q tensor",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="k",
|
||||
description="K tensor",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="cos",
|
||||
description="Cos cache",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="sin",
|
||||
description="Sin cache",
|
||||
type_str="T",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="output",
|
||||
description="Output tensor",
|
||||
type_str="T",
|
||||
)
|
||||
],
|
||||
type_constraints=[
|
||||
(
|
||||
"T",
|
||||
["tensor(float)", "tensor(float16)", "tensor(bfloat16)"],
|
||||
"Input and output data type.",
|
||||
),
|
||||
],
|
||||
attributes=[
|
||||
defs.OpSchema.Attribute(
|
||||
name="unsqueeze_dim",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Unsqueeze dimension. Must be 1 or 2.",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ONNX Custom Op Registration for AttentionPlugin
|
||||
_attention_plugin_schema = defs.OpSchema(
|
||||
name="AttentionPlugin",
|
||||
domain=_TRT_DOMAIN_NAME,
|
||||
since_version=1,
|
||||
doc="Fused RoPE + Attention operation for efficient inference.",
|
||||
inputs=[
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="qkv",
|
||||
description="Concatenated Q, K, V tensors in shape [batch, seq_len, qkv_hidden_size]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="past_key_values",
|
||||
description="Concatenated past K and V cache in shape [batch, 2, num_kv_heads, past_len, head_size]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="context_lengths",
|
||||
description="Context lengths for each sequence in shape [batch]",
|
||||
type_str="T1",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="rope_rotary_cos_sin",
|
||||
description="Concatenated cos and sin values for RoPE in shape [max_seq_len, head_dim]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="kvcache_start_index",
|
||||
description="KV cache start index for each sequence in shape [batch]",
|
||||
type_str="T1",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="output",
|
||||
description="Attention output in shape [batch, seq_len, hidden_size]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="present_key_values",
|
||||
description="Updated K and V cache",
|
||||
type_str="T",
|
||||
),
|
||||
],
|
||||
type_constraints=[
|
||||
(
|
||||
"T",
|
||||
["tensor(float16)", "tensor(float)", "tensor(bfloat16)"],
|
||||
"Input and output data type for floating point tensors.",
|
||||
),
|
||||
(
|
||||
"T1",
|
||||
["tensor(int32)", "tensor(int64)"],
|
||||
"Input data type for integer tensors.",
|
||||
),
|
||||
],
|
||||
attributes=[
|
||||
defs.OpSchema.Attribute(
|
||||
name="enable_tree_attention",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Whether to enable tree attention (0 or 1).",
|
||||
required=True,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="head_size",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Size of each attention head.",
|
||||
required=True,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="num_kv_heads",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Number of key-value heads.",
|
||||
required=True,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="num_q_heads",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Number of query heads.",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ONNX Custom Op Registration for torch_attention
|
||||
_torch_attention_schema = defs.OpSchema(
|
||||
name="torch_attention",
|
||||
domain=_AUTO_DEPLOY_DOMAIN_NAME,
|
||||
since_version=1,
|
||||
doc="SDPA attention (with optional GQA) that supports bnsd and bsnd memory layouts.",
|
||||
inputs=[
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="query",
|
||||
description="Query tensor [batch, seq_len_q/num_heads, num_heads/seq_len_q, head_dim]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="key",
|
||||
description="Key tensor [batch, seq_len_k/num_kv_heads, num_kv_heads/seq_len_k, head_dim]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="value",
|
||||
description="Value tensor [batch, seq_len_k/num_kv_heads, num_kv_heads/seq_len_k, head_dim]",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="attn_mask",
|
||||
description="Optional attention mask in [batch, num_heads, seq_len_q, seq_len_k] layout",
|
||||
type_str="T",
|
||||
),
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="sinks",
|
||||
description="Optional sinks tensor",
|
||||
type_str="T",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
defs.OpSchema.FormalParameter(
|
||||
name="output",
|
||||
description="Attention output in the same layout as inputs",
|
||||
type_str="T",
|
||||
)
|
||||
],
|
||||
type_constraints=[
|
||||
(
|
||||
"T",
|
||||
["tensor(float16)", "tensor(float)", "tensor(bfloat16)"],
|
||||
"Input and output data type for floating point tensors.",
|
||||
),
|
||||
],
|
||||
attributes=[
|
||||
defs.OpSchema.Attribute(
|
||||
name="dropout_p",
|
||||
type=defs.OpSchema.AttrType.FLOAT,
|
||||
description="Dropout probability.",
|
||||
required=False,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="is_causal",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Whether to apply causal masking (0 or 1).",
|
||||
required=False,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="scale",
|
||||
type=defs.OpSchema.AttrType.FLOAT,
|
||||
description="Attention scale factor.",
|
||||
required=False,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="sliding_window",
|
||||
type=defs.OpSchema.AttrType.INT,
|
||||
description="Sliding window size for attention.",
|
||||
required=False,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="logit_cap",
|
||||
type=defs.OpSchema.AttrType.FLOAT,
|
||||
description="Logit capping value.",
|
||||
required=False,
|
||||
),
|
||||
defs.OpSchema.Attribute(
|
||||
name="layout",
|
||||
type=defs.OpSchema.AttrType.STRING,
|
||||
description="Memory layout: 'bnsd' or 'bsnd'.",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def register_onnx_schemas():
|
||||
"""Register ONNX custom ops."""
|
||||
defs.register_schema(_torch_rope_with_explicit_cos_sin_schema)
|
||||
defs.register_schema(_torch_attention_schema)
|
||||
defs.register_schema(_attention_plugin_schema)
|
||||
@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import operator
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
@TransformRegistry.register("adapt_to_edgellm")
|
||||
class AdaptToEdgeLLM(BaseTransform):
|
||||
"""Transform that adapts the model graph for EdgeLLM deployment.
|
||||
|
||||
This transform performs several modifications to make the model compatible
|
||||
with EdgeLLM runtime requirements:
|
||||
|
||||
1. Converts all model weights to float16 precision
|
||||
2. Adds float32 cast after the final linear layer (for logits output)
|
||||
3. Inserts float16 casts after attention output reshapes
|
||||
4. Changes any bfloat16 casts to float16 (EdgeLLM may not support bfloat16)
|
||||
|
||||
These modifications ensure proper data type handling throughout the model
|
||||
while maintaining numerical precision where needed (e.g., logits output).
|
||||
"""
|
||||
|
||||
def _add_cast_after_last_linear(self, gm: GraphModule) -> torch.fx.Node:
|
||||
"""Add a float32 cast operation after the final linear layer.
|
||||
|
||||
The final linear layer produces logits which are typically kept in
|
||||
float32 for numerical stability during softmax and sampling operations.
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to modify.
|
||||
|
||||
Returns:
|
||||
The newly created cast node.
|
||||
"""
|
||||
graph = gm.graph
|
||||
linear_nodes = graph.find_nodes(
|
||||
op="call_function", target=torch.ops.auto_deploy.torch_linear_simple.default, sort=True
|
||||
)
|
||||
assert len(linear_nodes) > 0, "No linear nodes found"
|
||||
last_linear_node = linear_nodes[-1]
|
||||
with graph.inserting_after(last_linear_node):
|
||||
cast_node = graph.call_function(
|
||||
torch.ops.aten.to.dtype, args=(last_linear_node, torch.float32)
|
||||
)
|
||||
last_linear_node.replace_all_uses_with(cast_node)
|
||||
# Restore cast_node's input to last_linear_node
|
||||
cast_node.update_arg(0, last_linear_node)
|
||||
return cast_node
|
||||
|
||||
def _insert_cast_after_attn_reshape(self, gm: GraphModule) -> int:
|
||||
"""Insert float16 cast after reshape nodes following AttentionPlugin output.
|
||||
|
||||
The AttentionPlugin may output tensors that need explicit casting to float16
|
||||
before being consumed by subsequent operations (e.g., linear projections).
|
||||
This ensures consistent data types throughout the attention block.
|
||||
|
||||
Graph transformation:
|
||||
Before: AttentionPlugin[0] -> Reshape -> MatMul
|
||||
After: AttentionPlugin[0] -> Reshape -> Cast(to float16) -> MatMul
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to modify.
|
||||
|
||||
Returns:
|
||||
Number of cast nodes inserted.
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Find all AttentionPlugin nodes
|
||||
attention_plugin_nodes = graph.find_nodes(
|
||||
op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
|
||||
)
|
||||
|
||||
num_inserted = 0
|
||||
for attn_node in attention_plugin_nodes:
|
||||
# Find getitem[0] for this AttentionPlugin (first output)
|
||||
for user in attn_node.users:
|
||||
if is_op(user, operator.getitem) and user.args[1] == 0:
|
||||
getitem_0_node = user
|
||||
# Find reshape nodes that use this getitem[0]
|
||||
for reshape_user in list(getitem_0_node.users):
|
||||
if is_op(reshape_user, torch.ops.aten.reshape.default):
|
||||
reshape_node = reshape_user
|
||||
# Insert cast (to float16) after reshape
|
||||
with graph.inserting_after(reshape_node):
|
||||
cast_node = graph.call_function(
|
||||
torch.ops.aten.to.dtype,
|
||||
args=(reshape_node, torch.float16),
|
||||
)
|
||||
reshape_node.replace_all_uses_with(cast_node)
|
||||
# Fix: restore cast_node's input to reshape_node
|
||||
# (replace_all_uses_with also replaced it)
|
||||
cast_node.update_arg(0, reshape_node)
|
||||
num_inserted += 1
|
||||
ad_logger.debug(f"Inserted cast (to float16) after {reshape_node.name}")
|
||||
|
||||
return num_inserted
|
||||
|
||||
def _change_cast_bfloat16_to_float16(self, gm: GraphModule) -> int:
|
||||
"""Replace all bfloat16 cast operations with float16 casts.
|
||||
|
||||
EdgeLLM or certain hardware backends may not support bfloat16 natively.
|
||||
This method converts all bfloat16 casts to float16 for compatibility.
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to modify.
|
||||
|
||||
Returns:
|
||||
Number of cast operations changed.
|
||||
"""
|
||||
graph = gm.graph
|
||||
cast_nodes = graph.find_nodes(op="call_function", target=torch.ops.aten.to.dtype)
|
||||
num_changed = 0
|
||||
for cast_node in cast_nodes:
|
||||
if cast_node.args[1] == torch.bfloat16:
|
||||
cast_node.update_arg(1, torch.float16)
|
||||
num_changed += 1
|
||||
return num_changed
|
||||
|
||||
def _to_float16(self, gm: GraphModule) -> None:
|
||||
"""Convert all model parameters and buffers to float16 precision.
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to convert.
|
||||
"""
|
||||
gm.half()
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
self._to_float16(gm)
|
||||
logits_cast = self._add_cast_after_last_linear(gm)
|
||||
assert logits_cast is not None, "Failed to add cast after last linear"
|
||||
num_attn_casts = self._insert_cast_after_attn_reshape(gm)
|
||||
num_bfloat16_casts = self._change_cast_bfloat16_to_float16(gm)
|
||||
ad_logger.info(f"Changed {num_bfloat16_casts} bfloat16 casts to float16")
|
||||
ad_logger.info(f"Adapted EdgeLLM model (inserted {num_attn_casts} attention casts)")
|
||||
|
||||
return gm, TransformInfo(
|
||||
skipped=False, num_matches=num_attn_casts, is_clean=False, has_valid_shapes=True
|
||||
)
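# Standalone sketch of the cast-insertion idiom used by the helpers above
# (illustrative only): torch.fx's replace_all_uses_with rewires every consumer of
# `node`, including the freshly inserted cast itself, so the cast's own input must be
# restored with update_arg afterwards.
#
#     with graph.inserting_after(node):
#         cast = graph.call_function(torch.ops.aten.to.dtype, args=(node, torch.float16))
#     node.replace_all_uses_with(cast)  # also rewrites cast's arg 0 to point at cast
#     cast.update_arg(0, node)          # restore the intended producer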
@ -0,0 +1,418 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from onnxscript import ir, opset20
|
||||
from pydantic import Field
|
||||
from torch.export import Dim
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models import hf
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.logger import ad_logger
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
from . import _onnx_schemas
|
||||
from ._chat_template import process_chat_template
|
||||
from ._config_export import export_llm_config
|
||||
|
||||
|
||||
class ExportToONNXConfig(TransformConfig):
|
||||
"""Configuration for the export to ONNX transform."""
|
||||
|
||||
output_dir: Path = Field(
|
||||
description="The directory to save the exported ONNX model.",
|
||||
)
|
||||
is_eagle_base: bool = Field(
|
||||
description="Whether the model is an Eagle base model.",
|
||||
default=False,
|
||||
)
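# Hedged usage sketch: the transform config can be instantiated directly, e.g.
#
#     ExportToONNXConfig(output_dir=Path("./exported_models"), is_eagle_base=False)
#
# How the config is wired into the transform pipeline is determined by the AutoDeploy
# configuration that enables the "export_to_onnx" transform.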
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Custom translation functions for ONNX export
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _translate_rope_op(
|
||||
q: ir.Tensor, k: ir.Tensor, cos: ir.Tensor, sin: ir.Tensor, unsqueeze_dim: int
|
||||
):
|
||||
return _onnx_schemas.auto_deploy_opset.rope_with_explicit_cos_sin(
|
||||
q, k, cos, sin, unsqueeze_dim=unsqueeze_dim
|
||||
)
|
||||
|
||||
|
||||
def _translate_simple_linear_op(input: ir.Tensor, weight: ir.Tensor, bias: Optional[ir.Tensor]):
|
||||
weight = opset20.Transpose(weight, perm=[1, 0])
|
||||
if bias is None:
|
||||
return opset20.MatMul(input, weight)
|
||||
return opset20.Add(opset20.MatMul(input, weight), bias)
|
||||
|
||||
|
||||
def _translate_gather_nd_op(data: ir.Tensor, indices: ir.Tensor, batch_dims: int):
|
||||
return opset20.GatherND(data, indices, batch_dims=batch_dims)
|
||||
|
||||
|
||||
def _translate_rope_attention_op(
|
||||
qkv: ir.Tensor,
|
||||
past_key_values: ir.Tensor,
|
||||
context_lengths: ir.Tensor,
|
||||
rope_rotary_cos_sin: ir.Tensor,
|
||||
kvcache_start_index: ir.Tensor,
|
||||
enable_tree_attention: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
num_q_heads: int,
|
||||
):
|
||||
"""
|
||||
ONNX custom op translation function for AttentionPlugin.
|
||||
|
||||
This function creates a custom ONNX op node in the trt domain.
|
||||
The actual implementation will need to be provided in the inference engine
|
||||
(e.g., ONNX Runtime or TensorRT) that loads this ONNX model.
|
||||
|
||||
Note: This is a translation function for torch.onnx.export's custom_translation_table,
|
||||
so it should NOT have the @script() decorator.
|
||||
"""
|
||||
# Call the custom op from the trt domain
|
||||
return _onnx_schemas.trt_opset.AttentionPlugin(
|
||||
qkv,
|
||||
past_key_values,
|
||||
context_lengths,
|
||||
rope_rotary_cos_sin,
|
||||
kvcache_start_index,
|
||||
enable_tree_attention=enable_tree_attention,
|
||||
head_size=head_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_q_heads=num_q_heads,
|
||||
)
|
||||
|
||||
|
||||
def _translate_torch_attention_op(
|
||||
query: ir.Tensor,
|
||||
key: ir.Tensor,
|
||||
value: ir.Tensor,
|
||||
attn_mask: Optional[ir.Tensor] = None,
|
||||
sinks: Optional[ir.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
logit_cap: Optional[float] = None,
|
||||
layout: str = "bnsd",
|
||||
):
|
||||
"""
|
||||
ONNX custom op translation function for torch_attention.
|
||||
|
||||
This function creates a custom ONNX op node in the auto_deploy domain.
|
||||
The actual implementation will need to be provided in the inference engine
|
||||
(e.g., ONNX Runtime or TensorRT) that loads this ONNX model.
|
||||
|
||||
Note: This is a translation function for torch.onnx.export's custom_translation_table,
|
||||
so it should NOT have the @script() decorator.
|
||||
"""
|
||||
# Call the custom op from the auto_deploy domain
|
||||
return _onnx_schemas.auto_deploy_opset.torch_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
sinks,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=1 if is_causal else 0,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
logit_cap=logit_cap,
|
||||
layout=layout,
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("export_to_onnx")
|
||||
class ExportToONNX(BaseTransform):
|
||||
"""Transform that exports a PyTorch GraphModule to ONNX format for deployment.
|
||||
|
||||
This transform is responsible for:
|
||||
1. Exporting the model graph to ONNX format with dynamic shapes support
|
||||
2. Generating configuration files (config.json) for the exported model
|
||||
3. Saving tokenizer files (tokenizer.json, vocab.json, etc.)
|
||||
4. Processing and exporting chat templates
|
||||
|
||||
The exported ONNX model includes custom ops from the auto_deploy domain. These custom ops include:
- torch_onnx_attention_plugin: Fused RoPE + Attention for efficient inference (exported as EdgeLLM's custom op)
- torch_onnx_gather_nd: N-dimensional gather operation (exported as onnxscript.opset20.GatherND)
|
||||
|
||||
Note:
|
||||
This transform does NOT modify the input graph. It only exports the graph
|
||||
to external files and returns the original graph unchanged.
|
||||
|
||||
Attributes:
|
||||
config: ExportToONNXConfig containing output directory, and is_eagle_base flag.
|
||||
"""
|
||||
|
||||
config: ExportToONNXConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
"""Return the configuration class for this transform."""
|
||||
return ExportToONNXConfig
|
||||
|
||||
def _export_onnx_model(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
_factory: ModelFactory,
|
||||
_shared_config: SharedConfig,
|
||||
) -> bool:
|
||||
"""Export the GraphModule to ONNX format using torch.onnx.export with dynamo.
|
||||
|
||||
This method handles the core ONNX export logic:
|
||||
1. Extracts input placeholders and their metadata from the graph
|
||||
2. Configures dynamic shapes for batch size, sequence length, and KV cache
|
||||
3. Sets up custom translation table for auto_deploy custom ops
|
||||
4. Exports the model using torch.onnx.export with dynamo=True
|
||||
|
||||
Args:
|
||||
gm: The PyTorch FX GraphModule to export.
|
||||
cm: CachedSequenceInterface containing max batch size and sequence length info.
|
||||
_factory: ModelFactory instance (unused in this method).
|
||||
_shared_config: SharedConfig instance (unused in this method).
|
||||
|
||||
Returns:
|
||||
bool: True if export was successful.
|
||||
"""
|
||||
# Extract input placeholders from graph to build kwargs for export
|
||||
args = []
|
||||
kwargs = {}
|
||||
placeholders = gm.graph.find_nodes(op="placeholder")
|
||||
for ph in placeholders:
|
||||
kwargs[ph.name] = ph.meta["val"]
|
||||
args = tuple(args)
|
||||
|
||||
ad_logger.info("Placeholders args:")
|
||||
for i, e in enumerate(args):
|
||||
ad_logger.info(f" {i}: {placeholders[i].name:20} {e}")
|
||||
|
||||
ad_logger.info("Placeholders kwargs:")
|
||||
for k, v in kwargs.items():
|
||||
ad_logger.info(f" {k}: {v}")
|
||||
|
||||
# Build output path
|
||||
output_path = self.config.output_dir / "model.onnx"
|
||||
|
||||
# Build dynamic_shapes for dynamo export
|
||||
# For dynamo, we need to specify dynamic dimensions for each input tensor
|
||||
dynamic_shapes = {}
|
||||
dynamic_shapes["input_ids"] = {
|
||||
0: Dim("batch_size", min=0, max=cm.info.max_batch_size),
|
||||
1: Dim("seq_len", min=0, max=cm.info.max_seq_len),
|
||||
}
|
||||
# Add dynamic shapes for context_lengths and rope_rotary_cos_sin
|
||||
dynamic_shapes["context_lengths"] = {
|
||||
0: Dim("batch_size", min=0, max=cm.info.max_batch_size)
|
||||
}
|
||||
dynamic_shapes["rope_rotary_cos_sin"] = {
|
||||
0: Dim("rope_batch_size", min=1, max=16),
|
||||
1: Dim("max_position_embeddings", min=1, max=4096),
|
||||
}
|
||||
dynamic_shapes["kvcache_start_index"] = {
|
||||
0: Dim("kv_cache_start_batch_size", min=0, max=cm.info.max_batch_size)
|
||||
}
|
||||
# Add dynamic shapes for past_key_values
|
||||
num_layers = len(
|
||||
gm.graph.find_nodes(
|
||||
op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
|
||||
)
|
||||
)
|
||||
for i in range(num_layers):
|
||||
dynamic_shapes[f"past_key_values_{i}"] = {
|
||||
0: Dim("batch_size", min=0, max=cm.info.max_batch_size),
|
||||
3: Dim("past_len", min=1, max=4096),
|
||||
}
|
||||
dynamic_shapes["last_token_ids"] = {
|
||||
0: Dim("batch_size", min=0, max=cm.info.max_batch_size),
|
||||
1: Dim("num_selected_tokens", min=1, max=cm.info.max_seq_len),
|
||||
}
|
||||
|
||||
# Create custom translation table for ONNX export
|
||||
# Map torch custom ops to their corresponding onnxscript translation functions
|
||||
custom_translation_table = {
|
||||
# Before fuse rope attention
# NOTE (yoco): These 2 ops will be fused into the AttentionPlugin operation
# in the fuse_rope_attention transform.
# However, when TensorRT-LLM changes, the fusion might not be applied.
# And for debugging purposes we might want to export the .onnx to check the graph.
# So let's just keep them here.
|
||||
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin.default: _translate_rope_op,
|
||||
torch.ops.auto_deploy.torch_attention.default: _translate_torch_attention_op,
|
||||
# Before and after fuse rope attention
|
||||
torch.ops.auto_deploy.torch_linear_simple.default: _translate_simple_linear_op,
|
||||
# After fuse rope attention
|
||||
torch.ops.auto_deploy.torch_onnx_attention_plugin.default: _translate_rope_attention_op,
|
||||
torch.ops.auto_deploy.torch_onnx_gather_nd.default: _translate_gather_nd_op,
|
||||
}
|
||||
|
||||
# Prepare output names
|
||||
output_names = []
|
||||
output_node = gm.graph.find_nodes(op="output")[0]
|
||||
outputs = output_node.args[0]
|
||||
for output in outputs:
|
||||
output_names.append(output.name)
|
||||
output_names[0] = "logits"
|
||||
|
||||
# Register ONNX custom ops
|
||||
_onnx_schemas.register_onnx_schemas()
|
||||
|
||||
# Export the graph module to ONNX using dynamo (more advanced tracer)
|
||||
ad_logger.info(f"Exporting GraphModule to ONNX with dynamo: {output_path}")
|
||||
torch.onnx.export(
|
||||
gm,
|
||||
tuple(args),
|
||||
output_path,
|
||||
opset_version=20,
|
||||
kwargs=kwargs,
|
||||
dynamo=True,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
report=False,
|
||||
output_names=output_names,
|
||||
custom_translation_table=custom_translation_table,
|
||||
)
|
||||
|
||||
ad_logger.info(f"Successfully exported ONNX model to {output_path}")
|
||||
return True
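# Hedged verification sketch (not executed by this transform): the custom domains
# referenced by the exported model can be inspected with the onnx package.
#
#     import onnx
#     m = onnx.load(str(output_path))
#     print([(o.domain, o.version) for o in m.opset_import])
#     # expected to include ("trt", 1) and ("auto_deploy", 1) alongside the default opset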
|
||||
|
||||
def _export_config_json(
|
||||
self, factory: ModelFactory, output_dir: str, is_eagle_base: bool
|
||||
) -> None:
|
||||
"""Export model configuration to config.json.
|
||||
|
||||
Generates a configuration file containing model architecture parameters
|
||||
such as hidden size, number of layers, attention heads, etc.
|
||||
|
||||
Args:
|
||||
factory: The ModelFactory containing model configuration.
|
||||
output_dir: Directory path where config.json will be saved.
|
||||
is_eagle_base: If True, exports as "eagle3_base" model type,
|
||||
otherwise exports as "llm" model type.
|
||||
"""
|
||||
model_type = "eagle3_base" if is_eagle_base else "llm"
|
||||
assert isinstance(factory, hf.AutoModelFactory)
|
||||
model_config, _ = factory._get_model_config()
|
||||
model_config = export_llm_config(model_config, model_type)
|
||||
# Add reduced_vocab_size to config if vocabulary reduction is used
|
||||
reduced_vocab_size = None # TODO: Implement this
|
||||
if reduced_vocab_size is not None:
|
||||
model_config["reduced_vocab_size"] = reduced_vocab_size
|
||||
ad_logger.info(f"Added reduced_vocab_size={reduced_vocab_size} to config")
|
||||
|
||||
config_path = os.path.join(output_dir, "config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(model_config, f, indent=2)
|
||||
ad_logger.info(f"Model configuration saved to {config_path}")
|
||||
|
||||
def _export_json_files(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
_cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
_shared_config: SharedConfig,
|
||||
) -> None:
|
||||
"""Export all JSON configuration and tokenizer files required for deployment.
|
||||
|
||||
This method orchestrates the export of:
|
||||
1. config.json - Model architecture configuration
|
||||
2. Tokenizer files - added_tokens.json, special_tokens_map.json,
|
||||
tokenizer.json, tokenizer_config.json, vocab.json
|
||||
3. processed_chat_template.json - Processed chat template for inference
|
||||
|
||||
Args:
|
||||
gm: The GraphModule containing model configuration.
|
||||
_cm: CachedSequenceInterface (unused, kept for interface consistency).
|
||||
factory: ModelFactory used to initialize tokenizer and get model directory.
|
||||
_shared_config: SharedConfig (unused, kept for interface consistency).
|
||||
"""
|
||||
# Export model configuration (architecture params, layer counts, etc.)
|
||||
is_eagle_base = self.config.is_eagle_base
|
||||
output_dir = self.config.output_dir
|
||||
self._export_config_json(factory, output_dir, is_eagle_base)
|
||||
|
||||
# Export tokenizer files for text processing during inference
|
||||
# Includes: added_tokens.json, special_tokens_map.json, tokenizer.json,
|
||||
# tokenizer_config.json, vocab.json
|
||||
tokenizer = factory.init_tokenizer()
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Export processed chat template for conversational inference
|
||||
model_dir = factory.model
|
||||
process_chat_template(model_dir, output_dir)
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
"""Apply the ONNX export transform to the graph module.
|
||||
|
||||
This is the main entry point that orchestrates the export process:
|
||||
1. Creates the output directory if it doesn't exist
|
||||
2. Exports all JSON configuration files (config, tokenizer, chat template)
|
||||
3. Exports the ONNX model with dynamic shapes and custom ops
|
||||
|
||||
Note:
|
||||
Unlike other transforms, this does NOT modify the graph.
|
||||
It only performs side effects (writing files to disk).
|
||||
|
||||
Args:
|
||||
gm: The PyTorch FX GraphModule to export.
|
||||
cm: CachedSequenceInterface containing runtime constraints.
|
||||
factory: ModelFactory for tokenizer initialization.
|
||||
shared_config: SharedConfig for transform coordination.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- gm: The original GraphModule (unchanged)
|
||||
- info: TransformInfo with export status (is_clean=True since graph is unmodified)
|
||||
"""
|
||||
# Ensure output directory exists before writing any files
|
||||
if not self.config.output_dir.exists():
|
||||
self.config.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Step 1: Export all auxiliary JSON files (config, tokenizer, chat template)
|
||||
self._export_json_files(gm, cm, factory, shared_config)
|
||||
|
||||
# Step 2: Export the ONNX model with dynamic shapes
|
||||
success = self._export_onnx_model(gm, cm, factory, shared_config)
|
||||
|
||||
# Return original graph unchanged with export status info
|
||||
# This transform is "clean" because it doesn't modify the graph structure
|
||||
info = TransformInfo(
|
||||
skipped=not success,
|
||||
num_matches=1 if success else 0,
|
||||
is_clean=True, # Graph is not modified
|
||||
has_valid_shapes=True, # Shape validity is preserved
|
||||
)
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,551 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import operator
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import add_graph_input, add_graph_output, remove_graph_input
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import extract_op_args, is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
class MatchResult:
|
||||
"""Container for matched RoPE + attention pattern nodes.
|
||||
|
||||
This class stores all the relevant nodes and metadata from a successfully
|
||||
matched RoPE (Rotary Position Embedding) + attention pattern in the graph.
|
||||
It is used to facilitate the pattern replacement during the fusion transform.
|
||||
|
||||
Attributes:
|
||||
q: The original query tensor node before view/reshape.
|
||||
k: The original key tensor node before view/reshape.
|
||||
v: The original value tensor node before view/reshape.
|
||||
cos: The cosine embedding node for RoPE.
|
||||
sin: The sine embedding node for RoPE.
|
||||
attn_node: The attention operation node to be replaced.
|
||||
rope_node: The RoPE operation node to be fused.
|
||||
head_dim: Dimension of each attention head.
|
||||
num_q_heads: Number of query attention heads.
|
||||
num_kv_heads: Number of key-value attention heads (for GQA/MQA).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
q: Node,
|
||||
k: Node,
|
||||
v: Node,
|
||||
cos: Node,
|
||||
sin: Node,
|
||||
attn_node: Node,
|
||||
rope_node: Node,
|
||||
head_dim: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
"""Initialize MatchResult with matched pattern nodes and metadata.
|
||||
|
||||
Args:
|
||||
q: Query tensor node.
|
||||
k: Key tensor node.
|
||||
v: Value tensor node.
|
||||
cos: RoPE cosine node.
|
||||
sin: RoPE sine node.
|
||||
attn_node: Attention operation node.
|
||||
rope_node: RoPE operation node.
|
||||
head_dim: Attention head dimension.
|
||||
num_q_heads: Number of query heads.
|
||||
num_kv_heads: Number of key-value heads.
|
||||
"""
|
||||
self.q = q
|
||||
self.k = k
|
||||
self.v = v
|
||||
self.cos = cos
|
||||
self.sin = sin
|
||||
self.attn_node = attn_node
|
||||
self.rope_node = rope_node
|
||||
self.head_dim = head_dim
|
||||
self.num_q_heads = num_q_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
def __repr__(self):
|
||||
"""Return string representation of the match result."""
|
||||
return (
|
||||
f"MatchResult(q={self.q.name}, k={self.k.name}, v={self.v.name}, "
|
||||
f"cos={self.cos.name}, sin={self.sin.name}, head_dim={self.head_dim}, "
|
||||
f"num_q_heads={self.num_q_heads}, num_kv_heads={self.num_kv_heads})"
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_rope_attention")
|
||||
class FuseRopeAttention(BaseTransform):
|
||||
"""Transform that fuses RoPE and attention operations into a single AttentionPlugin.
|
||||
|
||||
This transform identifies patterns in the graph where RoPE (Rotary Position
|
||||
Embedding) is applied to query and key tensors followed by scaled dot-product
|
||||
attention, and replaces them with a fused AttentionPlugin operation.
|
||||
|
||||
The fusion provides several benefits:
|
||||
- Reduced memory bandwidth by eliminating intermediate tensors
|
||||
- Enables KV-cache integration for efficient autoregressive generation
|
||||
- Allows backend-specific optimizations (e.g., TensorRT, EdgeLLM)
|
||||
|
||||
Pattern matched (backwards from attention):
|
||||
1. torch_attention(rope_q, rope_k, view_v, attn_mask, ...)
|
||||
2. rope_q = rope[1], rope_k = rope[0]
|
||||
3. rope = torch_rope_with_explicit_cos_sin(cont_q, cont_k, cos, sin, 2)
|
||||
4. cont_q = contiguous(view_q), cont_k = contiguous(view_k)
|
||||
5. view_q = view(q, ...), view_k = view(k, ...), view_v = view(v, ...)
|
||||
|
||||
The transform also:
|
||||
- Adds new graph inputs: context_lengths, rope_rotary_cos_sin, kvcache_start_index
|
||||
- Adds past_key_values inputs for each attention layer
|
||||
- Adds present_key_values outputs for each attention layer
|
||||
- Removes the position_ids placeholder (no longer needed after fusion)
|
||||
"""
|
||||
|
||||
def _get_batch_size_and_max_seq_len(self, gm: GraphModule) -> Tuple[int, int, Node, Node]:
|
||||
"""Get batch size and max sequence length from the graph.
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to get batch size and max sequence length from.
|
||||
|
||||
Returns:
|
||||
Tuple of (batch_size_dim, max_seq_len_dim, batch_size_sym_node, max_seq_len_sym_node).
|
||||
|
||||
Note:
|
||||
For clarity, we return the symbolic nodes as well, which can be
|
||||
used in operations like view or reshape.
|
||||
"""
|
||||
graph = gm.graph
|
||||
input_ids_node = graph.find_nodes(op="placeholder", target="input_ids")[0]
|
||||
position_ids_node = graph.find_nodes(op="placeholder", target="position_ids")[0]
|
||||
input_ids_meta = input_ids_node.meta.get("val")
|
||||
batch_size_dim = input_ids_meta.size(0)
|
||||
max_seq_len_dim = input_ids_meta.size(1)
|
||||
|
||||
# We scan all the sym_size.int nodes to find the symbolic nodes for
|
||||
# batch size and max sequence length. The symbolic nodes have two sources:
|
||||
# 1. The input_ids placeholder.
|
||||
# 2. The position_ids placeholder.
|
||||
# Both of their shapes are (batch_size, max_seq_len).
|
||||
batch_size_sym_node = None
max_seq_len_sym_node = None
sym_ints = graph.find_nodes(op="call_function", target=torch.ops.aten.sym_size.int)
|
||||
for sym_int in sym_ints:
|
||||
if sym_int.args[0] != input_ids_node and sym_int.args[0] != position_ids_node:
|
||||
continue
|
||||
if sym_int.args[1] == 0:
|
||||
batch_size_sym_node = sym_int
|
||||
elif sym_int.args[1] == 1:
|
||||
max_seq_len_sym_node = sym_int
|
||||
assert batch_size_sym_node is not None and max_seq_len_sym_node is not None
|
||||
|
||||
return batch_size_dim, max_seq_len_dim, batch_size_sym_node, max_seq_len_sym_node
|
||||
|
||||
def _get_config_head_dim(self, model_config: transformers.PretrainedConfig) -> int:
|
||||
"""Get head dimension from model config."""
|
||||
if hasattr(model_config, "head_dim"):
|
||||
return model_config.head_dim
|
||||
else:
|
||||
return model_config.hidden_size // model_config.num_attention_heads
|
||||
|
||||
def _match_rope_attention_pattern(
|
||||
self, gm: GraphModule, model_config: transformers.PretrainedConfig
|
||||
) -> List[MatchResult]:
|
||||
"""Match RoPE + attention patterns in the computation graph.
|
||||
|
||||
Traverses the graph backwards from attention nodes to identify the
|
||||
complete pattern of RoPE application followed by attention computation.
|
||||
|
||||
Pattern structure (backwards from attention):
|
||||
1. torch_attention(rope_q, rope_k, bind_v, attn_mask, ...)
|
||||
2. rope_q, rope_k = rope[1], rope[0]
|
||||
3. rope = torch_rope_with_explicit_cos_sin(bind_q, bind_k, cos, sin, 2)
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to search for patterns.
model_config: HuggingFace model config used to derive head counts and head_dim.
|
||||
|
||||
Returns:
|
||||
List of MatchResult objects, each containing the matched nodes
|
||||
and extracted metadata (head_dim, num_q_heads, num_kv_heads).
|
||||
"""
|
||||
matches = []
|
||||
graph = gm.graph
|
||||
head_dim = self._get_config_head_dim(model_config)
|
||||
|
||||
# Iterate through all nodes to find attention ops
|
||||
for attn_node in graph.nodes:
|
||||
if not is_op(attn_node, torch.ops.auto_deploy.torch_attention):
|
||||
continue
|
||||
|
||||
if attn_node.args[10] != "bsnd":
|
||||
ad_logger.error(
|
||||
f" Skipping: attention layout is not bsnd: {attn_node.kwargs.get('layout', None)}"
|
||||
)
|
||||
continue
|
||||
|
||||
ad_logger.debug(f"Found attention node: {attn_node.name}")
|
||||
|
||||
# Extract attention inputs: (rope_q, rope_k, v, attn_mask, ...)
|
||||
if len(attn_node.args) < 4:
|
||||
ad_logger.error(f" Skipping: insufficient args ({len(attn_node.args)})")
|
||||
continue
|
||||
|
||||
rope_q_node, rope_k_node, bind_v = extract_op_args(attn_node, "query", "key", "value")
|
||||
|
||||
# Step 1: Match rope_q and rope_k as getitem[1] and getitem[0] from rope output
|
||||
if not (is_op(rope_q_node, operator.getitem) and is_op(rope_k_node, operator.getitem)):
|
||||
ad_logger.error(" Skipping: rope_q or rope_k not getitem")
|
||||
continue
|
||||
|
||||
# Verify they come from the same rope node
|
||||
assert rope_q_node.target == operator.getitem
|
||||
assert rope_k_node.target == operator.getitem
|
||||
rope_node_from_q, rope_q_idx = rope_q_node.args
|
||||
rope_node_from_k, rope_k_idx = rope_k_node.args
|
||||
|
||||
if rope_node_from_q != rope_node_from_k:
|
||||
ad_logger.error(" Skipping: rope_q and rope_k come from different rope nodes")
|
||||
continue
|
||||
|
||||
rope_node = rope_node_from_q
|
||||
|
||||
# Verify getitem indices: rope[0] = rope_k, rope[1] = rope_q
|
||||
if rope_k_idx != 0 or rope_q_idx != 1:
|
||||
ad_logger.error(
|
||||
f" Skipping: incorrect getitem indices (k={rope_k_idx}, q={rope_q_idx})"
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 2: Match the rope node
|
||||
if not is_op(rope_node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin):
|
||||
ad_logger.error(" Skipping: not a rope node")
|
||||
continue
|
||||
|
||||
# Extract rope inputs: (cont_q, cont_k, cos, sin, 2)
|
||||
if len(rope_node.args) < 5:
|
||||
ad_logger.error(f" Skipping: rope has insufficient args ({len(rope_node.args)})")
|
||||
continue
|
||||
|
||||
bind_k = rope_node.args[0]
|
||||
bind_q = rope_node.args[1]
|
||||
cos_node = rope_node.args[2]
|
||||
sin_node = rope_node.args[3]
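# NOTE: args[0]/args[1] are mapped to k/q here. By the op definition,
# torch_rope_with_explicit_cos_sin takes (q, k, ...), but the exported graphs
# observed so far pass (k, q); see the corresponding NOTE in the unit test.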
|
||||
|
||||
num_q_heads = model_config.num_attention_heads
|
||||
num_kv_heads = model_config.num_key_value_heads
|
||||
|
||||
# Successfully matched the pattern!
|
||||
match = MatchResult(
|
||||
q=bind_q,
|
||||
k=bind_k,
|
||||
v=bind_v,
|
||||
cos=cos_node,
|
||||
sin=sin_node,
|
||||
attn_node=attn_node,
|
||||
rope_node=rope_node,
|
||||
head_dim=head_dim,
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
)
|
||||
matches.append(match)
|
||||
ad_logger.debug(f" ✓ Matched pattern: {match}")
|
||||
|
||||
return matches
|
||||
|
||||
def _add_global_placeholders(self, gm: GraphModule, factory: ModelFactory) -> tuple:
|
||||
"""Add global input placeholders required by the fused AttentionPlugin.
|
||||
|
||||
Creates three new graph inputs that are shared across all attention layers:
|
||||
- context_lengths: Actual sequence length for each batch element
|
||||
- rope_rotary_cos_sin: Precomputed RoPE embeddings
|
||||
- kvcache_start_index: Starting position in KV-cache for each batch
|
||||
|
||||
These inputs enable dynamic sequence handling and efficient KV-cache
|
||||
management during autoregressive generation.
|
||||
|
||||
Args:
|
||||
gm: GraphModule to add placeholders to.
|
||||
factory: ModelFactory used to read the model config (e.g. head_dim).
|
||||
|
||||
Returns:
|
||||
Tuple of (context_lengths_node, rope_rotary_cos_sin_node,
|
||||
kvcache_start_index_node).
|
||||
"""
|
||||
|
||||
graph = gm.graph
|
||||
|
||||
# Find token_ids placeholder to get batch_size symbolic dimension
|
||||
token_ids_node = None
|
||||
for node in graph.nodes:
|
||||
if node.op == "placeholder" and "token" in node.name.lower():
|
||||
token_ids_node = node
|
||||
break
|
||||
|
||||
if token_ids_node is None:
|
||||
# Fallback: use first placeholder
|
||||
token_ids_node = graph.find_nodes(op="placeholder", sort=True)[0]
|
||||
|
||||
batch_size_dim, max_seq_len_dim, _, _ = self._get_batch_size_and_max_seq_len(gm)
|
||||
ad_logger.debug(f"Extracted batch_size={batch_size_dim}, max_seq_len={max_seq_len_dim}")
|
||||
|
||||
# 1. Add context_lengths placeholder: int32[batch_size]
|
||||
context_lengths_example = torch.zeros(batch_size_dim, dtype=torch.int32, device="meta")
|
||||
context_lengths_node = add_graph_input(
|
||||
gm, name="context_lengths", val=context_lengths_example
|
||||
)
|
||||
ad_logger.debug(f"Added context_lengths placeholder: {context_lengths_node.name}")
|
||||
|
||||
# 2. Add rope_rotary_cos_sin placeholder: float32[batch_size, max_seq_len, head_dim]
|
||||
# Create with concrete example tensor
|
||||
model_config, _ = factory._get_model_config()
|
||||
head_dim = self._get_config_head_dim(model_config)
|
||||
rope_example = torch.zeros(
|
||||
batch_size_dim, max_seq_len_dim, head_dim, dtype=torch.float32, device="meta"
|
||||
)
|
||||
rope_rotary_cos_sin_node = add_graph_input(gm, name="rope_rotary_cos_sin", val=rope_example)
|
||||
ad_logger.debug(f"Added rope_rotary_cos_sin placeholder: {rope_rotary_cos_sin_node.name}")
|
||||
|
||||
# 3. Add kvcache_start_index placeholder: int32[batch_size]
|
||||
kvcache_start_index_example = torch.zeros(batch_size_dim, dtype=torch.int32, device="meta")
|
||||
kvcache_start_index_node = add_graph_input(
|
||||
gm, name="kvcache_start_index", val=kvcache_start_index_example
|
||||
)
|
||||
ad_logger.debug(f"Added kvcache_start_index placeholder: {kvcache_start_index_node.name}")
|
||||
|
||||
return context_lengths_node, rope_rotary_cos_sin_node, kvcache_start_index_node
|
||||
|
||||
def _perform_replacement(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: "CachedSequenceInterface",
|
||||
matches: List[MatchResult],
|
||||
context_lengths_node: Node,
|
||||
rope_rotary_cos_sin_node: Node,
|
||||
kvcache_start_index_node: Node,
|
||||
) -> int:
|
||||
"""Replace matched RoPE + attention patterns with fused AttentionPlugin.
|
||||
|
||||
For each matched pattern, this method:
|
||||
1. Creates a past_key_values input placeholder for KV-cache
|
||||
2. Reshapes Q, K, V to (batch_size, seq_len, -1)
|
||||
3. Concatenates reshaped Q, K, V tensors into a single QKV tensor
|
||||
4. Inserts the fused AttentionPlugin operation
|
||||
5. Creates getitem nodes to extract attention output and present KV-cache
|
||||
6. Replaces the original attention node with the fused output
|
||||
7. Adds present_key_values to graph outputs
|
||||
|
||||
Args:
|
||||
gm: The GraphModule being transformed.
|
||||
cm: CachedSequenceInterface for shape management.
|
||||
matches: List of matched patterns to replace.
|
||||
context_lengths_node: Shared context lengths input node.
|
||||
rope_rotary_cos_sin_node: Shared RoPE embeddings input node.
|
||||
kvcache_start_index_node: Shared KV-cache index input node.
|
||||
|
||||
Returns:
|
||||
Number of patterns successfully replaced.
|
||||
"""
|
||||
graph = gm.graph
|
||||
past_len = 4096 # Does not matter, this value will be replaced by symbolic dimension
|
||||
|
||||
# Get batch size & max sequence length
|
||||
batch_size_dim, _, batch_size_sym_node, max_seq_len_sym_node = (
|
||||
self._get_batch_size_and_max_seq_len(gm)
|
||||
)
|
||||
|
||||
# Process each match
|
||||
for match_id, match in enumerate(matches):
|
||||
ad_logger.debug(f"Processing match {match_id}: {match}")
|
||||
|
||||
# 1. Create past_key_values_<id> placeholder
|
||||
# Shape: float16[batch_size, 2, num_kv_heads, past_len, head_dim]
|
||||
past_key_values_example = torch.zeros(
|
||||
batch_size_dim,
|
||||
2,
|
||||
match.num_kv_heads,
|
||||
past_len,
|
||||
match.head_dim,
|
||||
dtype=torch.float16,
|
||||
device="meta",
|
||||
)
|
||||
past_key_values_node = add_graph_input(
|
||||
gm, name=f"past_key_values_{match_id}", val=past_key_values_example
|
||||
)
|
||||
|
||||
ad_logger.debug(f"Added past_key_values_{match_id} placeholder")
|
||||
|
||||
# 2. Reshape Q, K, V to (batch_size, seq_len, -1)
|
||||
with graph.inserting_before(match.attn_node):
|
||||
q_node = graph.call_function(
|
||||
torch.ops.aten.view.default,
|
||||
args=(match.q, (batch_size_sym_node, max_seq_len_sym_node, -1)),
|
||||
)
|
||||
k_node = graph.call_function(
|
||||
torch.ops.aten.view.default,
|
||||
args=(match.k, (batch_size_sym_node, max_seq_len_sym_node, -1)),
|
||||
)
|
||||
v_node = graph.call_function(
|
||||
torch.ops.aten.view.default,
|
||||
args=(match.v, (batch_size_sym_node, max_seq_len_sym_node, -1)),
|
||||
)
|
||||
|
||||
ad_logger.debug(
|
||||
f"Reshaped Q, K, V to (batch_size, seq_len, -1): {q_node.name}, {k_node.name}, {v_node.name}"
|
||||
)
|
||||
|
||||
# 3. Concatenate reshaped Q, K, V along the last dimension into a single QKV tensor
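# Given the per-head projections feeding this layer, the fused tensor has shape
# (batch_size, seq_len, (num_q_heads + 2 * num_kv_heads) * head_dim).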
|
||||
with graph.inserting_before(match.attn_node):
|
||||
qkv_node = graph.call_function(
|
||||
torch.ops.aten.cat.default, args=([q_node, k_node, v_node], -1)
|
||||
)
|
||||
|
||||
ad_logger.debug(f"Created qkv concat node: {qkv_node.name}")
|
||||
|
||||
# 4. Create AttentionPlugin node
|
||||
enable_tree_attention = 0
|
||||
head_dim = match.head_dim
|
||||
num_kv_heads = match.num_kv_heads
|
||||
num_q_heads = match.num_q_heads
|
||||
|
||||
with graph.inserting_before(match.attn_node):
|
||||
cast_node = graph.call_function(
|
||||
torch.ops.aten.to.dtype, args=(qkv_node, torch.float16)
|
||||
)
|
||||
rope_attn_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_onnx_attention_plugin.default,
|
||||
args=(
|
||||
cast_node,
|
||||
past_key_values_node,
|
||||
context_lengths_node,
|
||||
rope_rotary_cos_sin_node,
|
||||
kvcache_start_index_node,
|
||||
enable_tree_attention,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
num_q_heads,
|
||||
),
|
||||
)
|
||||
|
||||
ad_logger.debug(f"Created AttentionPlugin node: {rope_attn_node.name}")
|
||||
|
||||
# 5. Create getitem nodes for rope_attn outputs
|
||||
with graph.inserting_after(rope_attn_node):
|
||||
rope_attn_output = graph.call_function(operator.getitem, args=(rope_attn_node, 0))
|
||||
rope_attn_present_kv = graph.call_function(
|
||||
operator.getitem,
|
||||
args=(rope_attn_node, 1),
|
||||
name=f"present_key_values_{match_id}",
|
||||
)
|
||||
|
||||
ad_logger.debug(
|
||||
f"Created getitem nodes: {rope_attn_output.name}, {rope_attn_present_kv.name}"
|
||||
)
|
||||
|
||||
# 6. Replace all uses of attn_node with rope_attn_output
|
||||
match.attn_node.replace_all_uses_with(rope_attn_output)
|
||||
ad_logger.debug(f"Replaced {match.attn_node.name} with {rope_attn_output.name}")
|
||||
|
||||
# 7. Add present_key_values_<id> to graph outputs using add_graph_output
|
||||
add_graph_output(gm, rope_attn_present_kv, f"present_key_values_{match_id}")
|
||||
ad_logger.debug(f"Added present_key_values_{match_id} to graph outputs")
|
||||
|
||||
# 8. Eliminate dead code (remove old pattern nodes)
|
||||
graph.eliminate_dead_code()
|
||||
ad_logger.debug("Eliminated dead code")
|
||||
|
||||
return len(matches)
|
||||
|
||||
def _remove_position_ids_placeholder(self, gm: GraphModule) -> str:
|
||||
"""Remove the position_ids placeholder from the graph.
|
||||
|
||||
After fusing RoPE into AttentionPlugin, position_ids is no longer needed
|
||||
as an input since positional information is now provided through
|
||||
rope_rotary_cos_sin and kvcache_start_index.
|
||||
|
||||
This method also updates any sym_size operations that were reading from
|
||||
position_ids to use input_ids instead.
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to modify.
|
||||
|
||||
Returns:
|
||||
Name of the removed position_ids node.
|
||||
"""
|
||||
graph = gm.graph
|
||||
position_ids_node = graph.find_nodes(op="placeholder", target="position_ids")[0]
|
||||
input_ids_node = graph.find_nodes(op="placeholder", target="input_ids")[0]
|
||||
|
||||
# Redirect sym_size reads from the position_ids placeholder to the input_ids placeholder
|
||||
sym_size_int_nodes = graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.sym_size.int
|
||||
)
|
||||
for node in sym_size_int_nodes:
|
||||
if node.args[0] == position_ids_node:
|
||||
new_args = (input_ids_node, *node.args[1:])
|
||||
node.args = new_args
|
||||
|
||||
return remove_graph_input(gm, position_ids_node)
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
ad_logger.debug("Fusing rope attention")
|
||||
|
||||
# Perform pattern matching
|
||||
model_config, _ = factory._get_model_config()
|
||||
matches = self._match_rope_attention_pattern(gm, model_config)
|
||||
cnt = len(matches)
|
||||
|
||||
ad_logger.info(f"Matched {cnt} rope attention patterns")
|
||||
|
||||
if cnt == 0:
|
||||
ad_logger.info("No patterns matched, skipping replacement")
|
||||
info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
|
||||
return gm, info
|
||||
|
||||
# Add global placeholders
|
||||
context_lengths_node, rope_rotary_cos_sin_node, kvcache_start_index_node = (
|
||||
self._add_global_placeholders(gm, factory)
|
||||
)
|
||||
|
||||
# Perform replacement for all matches
|
||||
num_replaced = self._perform_replacement(
|
||||
gm,
|
||||
cm,
|
||||
matches,
|
||||
context_lengths_node,
|
||||
rope_rotary_cos_sin_node,
|
||||
kvcache_start_index_node,
|
||||
)
|
||||
|
||||
# Remove position_ids placeholder
|
||||
removed_node_name = self._remove_position_ids_placeholder(gm)
|
||||
ad_logger.debug(f"Removed position_ids placeholder: {removed_node_name}")
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=num_replaced, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
ad_logger.info(f"Fused {num_replaced} rope attention patterns")
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import add_graph_input
|
||||
from ...utils.logger import ad_logger
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
@TransformRegistry.register("gather_last_token_ids")
|
||||
class GatherLastTokenIds(BaseTransform):
|
||||
def _find_last_linear_simple(self, gm: GraphModule) -> Node:
|
||||
"""Find the final torch_linear_simple node that produces logits.
|
||||
This is the last matmul operation in the graph
|
||||
that produces the vocabulary logits.
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Look for the output node and trace back to find MatMul
|
||||
output_node = graph.find_nodes(op="output")[0]
|
||||
|
||||
# Look for the last torch_linear_simple node
|
||||
linear_simple_nodes = graph.find_nodes(
|
||||
op="call_function", target=torch.ops.auto_deploy.torch_linear_simple.default, sort=True
|
||||
)
|
||||
assert len(linear_simple_nodes) > 0, "No linear simple nodes found"
|
||||
last_linear_simple_node = linear_simple_nodes[-1]
|
||||
|
||||
# Verify that the last linear simple node is the one producing the logits
|
||||
if last_linear_simple_node not in output_node.args[0]:
|
||||
ad_logger.warning(
|
||||
f"Last linear simple node {last_linear_simple_node.name} is not the one producing the logits"
|
||||
)
|
||||
return None
|
||||
|
||||
return last_linear_simple_node
|
||||
|
||||
def _add_last_token_ids_input(self, gm: GraphModule) -> Node:
|
||||
"""Add last_token_ids as a graph input.
|
||||
|
||||
Shape: int64[batch_size, num_selected_tokens]
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Find token_ids or input_ids placeholder to get batch_size dimension
|
||||
token_ids_node = None
|
||||
for node in graph.nodes:
|
||||
if node.op == "placeholder" and (
|
||||
"token" in node.name.lower() or "input_ids" in node.name
|
||||
):
|
||||
token_ids_node = node
|
||||
break
|
||||
|
||||
if token_ids_node is None:
|
||||
# Fallback: use first placeholder
|
||||
token_ids_node = graph.find_nodes(op="placeholder", sort=True)[0]
|
||||
|
||||
ad_logger.info(f"Using {token_ids_node.name} to extract batch_size dimension")
|
||||
|
||||
# Get symbolic batch_size dimension
|
||||
token_ids_meta = token_ids_node.meta.get("val")
|
||||
assert token_ids_meta is not None, "token_ids_meta is None"
|
||||
batch_size_dim = token_ids_meta.size(0)
|
||||
ad_logger.info(f"Extracted batch_size={batch_size_dim}")
|
||||
|
||||
# Add last_token_ids placeholder: int64[batch_size, num_selected_tokens]
|
||||
num_selected_tokens = 2
|
||||
input_shape = (batch_size_dim, num_selected_tokens)
|
||||
last_token_ids_example = torch.zeros(input_shape, dtype=torch.int64, device="meta")
|
||||
last_token_ids_node = add_graph_input(gm, name="last_token_ids", val=last_token_ids_example)
|
||||
|
||||
ad_logger.info(f"Added last_token_ids placeholder: {last_token_ids_node.name}")
|
||||
|
||||
return last_token_ids_node
|
||||
|
||||
def _insert_gather_nd(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
linear_simple: Node,
|
||||
last_token_ids_node: Node,
|
||||
) -> bool:
|
||||
"""Insert GatherND operation before MatMul.
|
||||
|
||||
The pattern to create:
|
||||
1. Create GatherND with the linear input and last_token_ids
|
||||
2. Replace MatMul input with GatherND output
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Get the input to linear_simple (should be from the previous layer)
|
||||
linear_input = linear_simple.args[0]
|
||||
|
||||
ad_logger.info(
|
||||
f"Linear input: {linear_input.name if hasattr(linear_input, 'name') else linear_input}"
|
||||
)
|
||||
|
||||
# 1. Insert GatherND operation before linear_simple
|
||||
# GatherND takes the hidden states (linear_input) and gathers based on last_token_ids
|
||||
with graph.inserting_before(linear_simple):
|
||||
unsqueeze_node = graph.call_function(
|
||||
torch.ops.aten.unsqueeze.default, args=(last_token_ids_node, -1)
|
||||
)
|
||||
gather_nd_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_onnx_gather_nd.default,
|
||||
args=(linear_input, unsqueeze_node, 1), # Use linear_input, not linear_simple!
|
||||
)
|
||||
ad_logger.info(f"Created GatherND node: {gather_nd_node.name}")
|
||||
|
||||
# 2. Replace linear_simple's first input with GatherND output
|
||||
linear_simple.replace_input_with(linear_input, gather_nd_node)
|
||||
ad_logger.info("Replaced Linear input with GatherND output")
|
||||
|
||||
return True
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
ad_logger.info("Adding last_token_ids gather operation")
|
||||
|
||||
# Step 1: Find last linear simple node
|
||||
linear_simple_node = self._find_last_linear_simple(gm)
|
||||
|
||||
if linear_simple_node is None:
|
||||
ad_logger.info("Could not find last linear simple node, skipping")
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
# Step 2: Add last_token_ids input
|
||||
last_token_ids_node = self._add_last_token_ids_input(gm)
|
||||
|
||||
# Step 3: Insert GatherND operation
|
||||
success = self._insert_gather_nd(gm, linear_simple_node, last_token_ids_node)
|
||||
|
||||
if not success:
|
||||
ad_logger.info("Failed to insert GatherND operation")
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
@ -0,0 +1,165 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
@TransformRegistry.register("short_reshape_attention_output")
|
||||
class ShortReshapeAttentionOutput(BaseTransform):
|
||||
"""Transform that simplifies reshape operations after attention outputs.
|
||||
|
||||
This transform optimizes reshape nodes that follow AttentionPlugin outputs
|
||||
by replacing fixed shape dimensions with dynamic symbolic dimensions derived
|
||||
from the input tensor. This ensures the reshape operation correctly handles
|
||||
dynamic batch sizes and sequence lengths.
|
||||
|
||||
The transform identifies patterns where:
|
||||
1. A reshape node follows an AttentionPlugin output
|
||||
2. The reshape's input can be traced back to a linear projection
|
||||
|
||||
It then replaces the reshape with a new one that uses symbolic dimensions
|
||||
from the linear projection's input, resulting in the shape [batch, seq, -1].
|
||||
"""
|
||||
|
||||
def _lookup_ascending_node(self, node: Node, target, max_depth: int = 3) -> Optional[Node]:
|
||||
"""Recursively search for an ancestor node with a specific target operation.
|
||||
|
||||
Traverses the graph backwards through node arguments to find a node
|
||||
that matches the specified target operation type.
|
||||
|
||||
Args:
|
||||
node: Starting node for the backward search.
|
||||
target: Target operation to search for (e.g., torch.ops.aten.reshape.default).
|
||||
max_depth: Maximum number of levels to traverse (default: 3).
|
||||
|
||||
Returns:
|
||||
The matching ancestor node if found, None otherwise.
|
||||
"""
|
||||
if max_depth == 0:
|
||||
return None
|
||||
if node.target == target:
|
||||
return node
|
||||
|
||||
# Helper function to check a single node
|
||||
def check_node(n):
|
||||
if isinstance(n, Node):
|
||||
result = self._lookup_ascending_node(n, target, max_depth - 1)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
# Check all arguments
|
||||
for arg in node.args:
|
||||
if isinstance(arg, (tuple, list)):
|
||||
for item in arg:
|
||||
result = check_node(item)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = check_node(arg)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
def _find_reshape_attention_output(self, gm: GraphModule) -> List[Tuple[Node, Node]]:
|
||||
"""Find reshape nodes that follow AttentionPlugin outputs.
|
||||
|
||||
Searches for reshape operations that:
|
||||
1. Have an AttentionPlugin in their input chain
|
||||
2. Can be traced back to a torch_linear_simple operation
|
||||
|
||||
The linear operation is needed to extract the original input shape
|
||||
for constructing the new reshape with symbolic dimensions.
|
||||
|
||||
Args:
|
||||
gm: The GraphModule to search.
|
||||
|
||||
Returns:
|
||||
List of (reshape_node, linear_node) tuples representing the
|
||||
matched patterns.
|
||||
"""
|
||||
reshape_linear_pairs = []
|
||||
reshape_nodes = gm.graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.reshape.default
|
||||
)
|
||||
|
||||
for node in reshape_nodes:
|
||||
# Looking for AttentionPlugin from the input of the reshape node
|
||||
attention_plugin_node = self._lookup_ascending_node(
|
||||
node.args[0], torch.ops.auto_deploy.torch_onnx_attention_plugin.default
|
||||
)
|
||||
if attention_plugin_node is None:
|
||||
continue
|
||||
|
||||
# Looking for torch_linear_simple from the input of the AttentionPlugin node
|
||||
linear_simple_node = self._lookup_ascending_node(
|
||||
attention_plugin_node.args[0], torch.ops.auto_deploy.torch_linear_simple.default
|
||||
)
|
||||
if linear_simple_node is None:
|
||||
continue
|
||||
|
||||
reshape_linear_pairs.append((node, linear_simple_node))
|
||||
|
||||
return reshape_linear_pairs
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
num_matches = 0
|
||||
reshape_linear_pairs = self._find_reshape_attention_output(gm)
|
||||
self._log_info(f"Found {len(reshape_linear_pairs)} reshape_linear_pairs")
|
||||
|
||||
graph = gm.graph
|
||||
for reshape_node, linear_simple_node in reshape_linear_pairs:
|
||||
# Get the input tensor of reshape node (keep it unchanged)
|
||||
shape_src = linear_simple_node.args[0]
|
||||
data_input = reshape_node.args[0]
|
||||
|
||||
# Extract first two dimensions from the input tensor
|
||||
with graph.inserting_before(reshape_node):
|
||||
dim0 = graph.call_function(torch.ops.aten.sym_size.int, args=(shape_src, 0))
|
||||
dim1 = graph.call_function(torch.ops.aten.sym_size.int, args=(shape_src, 1))
|
||||
|
||||
# Create new shape using input tensor's first two dimensions
|
||||
new_shape = [dim0, dim1, -1]
|
||||
|
||||
# Create a NEW reshape node instead of modifying the existing one
|
||||
# This ensures the graph dependencies are correctly established
|
||||
new_reshape_node = graph.call_function(
|
||||
torch.ops.aten.reshape.default, args=(data_input, new_shape)
|
||||
)
|
||||
|
||||
# Replace all uses of the old reshape node with the new one
|
||||
reshape_node.replace_all_uses_with(new_reshape_node)
|
||||
|
||||
# Remove the old reshape node from the graph
|
||||
graph.erase_node(reshape_node)
|
||||
|
||||
num_matches += 1
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
|
||||
)
|
||||
return gm, info
|
||||
@ -64,11 +64,11 @@ class InferenceOptimizer:
|
||||
mod = nn.Module()
|
||||
|
||||
# iterate over all transforms sorted by stage in the config
|
||||
for t_name, t_config in self.config.items():
|
||||
for idx, (t_name, t_config) in enumerate(self.config.items()):
|
||||
# instantiate transform
|
||||
transform = TransformRegistry.get(t_name)(t_config)
|
||||
# run transform
|
||||
mod = transform(mod, cm, self.factory, self.shared_config)
|
||||
mod = transform(mod, cm, self.factory, self.shared_config, idx)
|
||||
|
||||
############################################################################################
|
||||
# RETURN OPTIMIZED MODEL
|
||||
|
||||
@ -15,12 +15,26 @@ from torch._subclasses import FakeTensor, FakeTensorMode
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils._pytree import _LEAF_SPEC
|
||||
from torch.utils._pytree import _LEAF_SPEC, TreeSpec
|
||||
|
||||
from .logger import ad_logger
|
||||
from .node_utils import get_weight_tensor, is_op
|
||||
|
||||
_NoValType = type("_NoValType", (), {})
|
||||
|
||||
|
||||
def _call_post_init_recursive(spec: TreeSpec) -> None:
|
||||
"""Recursively call __post_init__ on TreeSpec starting from leaves.
|
||||
|
||||
This is needed to update internal cached values (like num_leaves) after
|
||||
modifying children_specs.
|
||||
"""
|
||||
if hasattr(spec, "children_specs"):
|
||||
for child_spec in spec.children_specs:
|
||||
_call_post_init_recursive(child_spec)
|
||||
spec.__post_init__()
|
||||
|
||||
|
||||
_NO_VAL = _NoValType()
|
||||
|
||||
|
||||
@ -367,12 +381,7 @@ def add_graph_input(
|
||||
in_spec_for_args.context.append(name)
|
||||
|
||||
# update pytree info recursively with __post_init__ starting at leaves
|
||||
def call_post_init(spec):
|
||||
for child_spec in spec.children_specs:
|
||||
call_post_init(child_spec)
|
||||
spec.__post_init__()
|
||||
|
||||
call_post_init(in_spec)
|
||||
_call_post_init_recursive(in_spec)
|
||||
|
||||
# set fake tensor information if all required information is available
|
||||
fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
|
||||
@ -527,3 +536,194 @@ def del_attr_by_name(obj, name):
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
delattr(obj, parts[-1])
|
||||
|
||||
|
||||
def add_graph_output(gm: GraphModule, output_node: Node, name: str) -> None:
|
||||
"""Add a graph output to the given GraphModule.
|
||||
|
||||
This function appends a new output to the graph's output node and updates
|
||||
the pytree_info metadata accordingly.
|
||||
|
||||
NOTE: function does NOT do any graph canonicalization. This is left to the user!
|
||||
|
||||
Args:
|
||||
gm (GraphModule): The GraphModule to add the output to.
|
||||
output_node (Node): The node to add as an output.
|
||||
name (str): The name of the new output in the output dict.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the graph has no output node or multiple output nodes.
|
||||
|
||||
Implementation Notes (Hacky Details):
|
||||
1. **Handling single-output graphs (out_spec == _LEAF_SPEC)**:
|
||||
When the graph originally has a single output, its out_spec is a
|
||||
_LEAF_SPEC (a leaf node in the pytree). To add a new output, we must
|
||||
first convert it to a container spec (tuple with children_specs).
|
||||
|
||||
2. **Forcing output type to dict**:
|
||||
The original out_spec.type is typically a model-specific output class
|
||||
(e.g., CausalLMOutputWithPast from transformers). When we add new
|
||||
outputs with custom names (like "present_key_values_xx"), the original
|
||||
class constructor will fail because these names are not valid data
|
||||
members of that class.
|
||||
|
||||
PyTorch uses the out_spec.type's constructor to wrap the output:
|
||||
result_obj = result_type(**output)
|
||||
|
||||
To avoid constructor failures, we forcibly change out_spec.type to
|
||||
dict using object.__setattr__ (since TreeSpec is a frozen dataclass).
|
||||
This ensures the output is returned as a plain dictionary instead of
|
||||
the original model-specific class.
|
||||
"""
|
||||
# extract graph and output spec
|
||||
graph: Graph = gm.graph
|
||||
|
||||
# find the output node
|
||||
output_nodes = graph.find_nodes(op="output")
|
||||
if len(output_nodes) != 1:
|
||||
raise RuntimeError(f"Expected exactly one output node, found {len(output_nodes)}")
|
||||
|
||||
graph_output_node = output_nodes[0]
|
||||
|
||||
# get current outputs
|
||||
current_outputs = graph_output_node.args[0]
|
||||
if not isinstance(current_outputs, (tuple, list)):
|
||||
# single output, convert to tuple
|
||||
current_outputs = (current_outputs,)
|
||||
|
||||
# append new output
|
||||
new_outputs = tuple(current_outputs) + (output_node,)
|
||||
graph_output_node.args = (new_outputs,)
|
||||
|
||||
# update pytree info: append spec for the new output
|
||||
# The out_spec structure mirrors the output structure
|
||||
# For a tuple of outputs, out_spec should be a TreeSpec with children_specs
|
||||
# And graph._codegen.pytree_info is a NamedTuple, so we have to use _replace to replace the out_spec
|
||||
out_spec = graph._codegen.pytree_info.out_spec
|
||||
assert out_spec is not None, "Graph must have an out_spec"
|
||||
|
||||
if out_spec == _LEAF_SPEC:
|
||||
new_out_spec = TreeSpec(type=tuple, children_specs=[_LEAF_SPEC], context=["output"])
|
||||
graph._codegen.pytree_info = graph._codegen.pytree_info._replace(out_spec=new_out_spec)
|
||||
out_spec = graph._codegen.pytree_info.out_spec
|
||||
if hasattr(out_spec, "children_specs"):
|
||||
# out_spec is already a container (tuple/list), append to it
|
||||
out_spec.children_specs.append(_LEAF_SPEC)
|
||||
out_spec.context.append(name)
|
||||
else:
|
||||
# This shouldn't happen in normal cases, but handle it just in case
|
||||
# If out_spec is a single leaf, we need to convert it to a tuple spec
|
||||
raise NotImplementedError(
|
||||
"Cannot add output to a graph with non-container out_spec. "
|
||||
"This case is not currently supported."
|
||||
)
|
||||
|
||||
# update pytree info recursively with __post_init__ starting at leaves
|
||||
_call_post_init_recursive(out_spec)
|
||||
|
||||
# Change the type of the output spec to dict.
#
# NOTE(yocox): This affects how torch interprets the output of the graph.
# The original type is some model-specific output class, for example
# CausalLMOutputWithPast, depending on the model. The type's constructor
# is used to create an object containing the result, for example,
#
#     result_obj = result_type(**output)
#
# Because we added new outputs with names like "present_key_values_xx",
# which are not data members of CausalLMOutputWithPast, that constructor
# would fail, so we change the type to dict.
#
# However, out_spec is a frozen dataclass, so we need object.__setattr__
# to change its type.
|
||||
object.__setattr__(out_spec, "type", dict)
|
||||
|
||||
|
||||
def remove_graph_input(gm: GraphModule, input_node: Node) -> str:
|
||||
"""Remove a graph input from the given GraphModule.
|
||||
|
||||
This is the inverse operation of add_graph_input(). It removes a placeholder node
|
||||
and updates the pytree_info metadata accordingly. The function automatically detects
|
||||
whether the node belongs to args or kwargs and updates the correct spec.
|
||||
|
||||
NOTE: function does NOT do any graph canonicalization. This is left to the user!
|
||||
|
||||
Args:
|
||||
gm (GraphModule): The GraphModule to remove the input from.
|
||||
input_node (Node): The placeholder node to remove.
|
||||
|
||||
Returns:
|
||||
str: The name of the removed node.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided Node is not a placeholder or not in the graph.
|
||||
RuntimeError: If the input node is still being used by other nodes.
|
||||
"""
|
||||
graph: Graph = gm.graph
|
||||
|
||||
# validate that the node is a placeholder
|
||||
if input_node.op != "placeholder":
|
||||
raise ValueError(
|
||||
f"Node '{input_node.name}' is not a placeholder node (op='{input_node.op}'). "
|
||||
f"Only placeholder nodes can be removed as graph inputs."
|
||||
)
|
||||
|
||||
# find all placeholder nodes and validate the node is in this graph
|
||||
placeholder_nodes = graph.find_nodes(op="placeholder", sort=True)
|
||||
if input_node not in placeholder_nodes:
|
||||
raise ValueError(
|
||||
f"Node '{input_node.name}' is not found in the graph's placeholder nodes. "
|
||||
f"It may belong to a different graph or have already been removed."
|
||||
)
|
||||
|
||||
# check that the node is not being used
|
||||
if len(input_node.users) > 0:
|
||||
user_names = [user.name for user in input_node.users.keys()]
|
||||
raise RuntimeError(
|
||||
f"Cannot remove input '{input_node.name}' "
|
||||
f"because it is still being used by nodes: {user_names}. "
|
||||
f"Please remove or replace all usages first."
|
||||
)
|
||||
|
||||
# get pytree info
|
||||
in_spec = graph._codegen.pytree_info.in_spec
|
||||
args_spec = in_spec.children_specs[0] # tuple for args
|
||||
kwargs_spec = in_spec.children_specs[1] # dict for kwargs
|
||||
orig_args = graph._codegen.pytree_info.orig_args
|
||||
|
||||
# determine the global index of the node
|
||||
global_idx = placeholder_nodes.index(input_node)
|
||||
|
||||
# determine number of args (kwargs come after args in placeholder order)
|
||||
num_args = len(args_spec.children_specs)
|
||||
|
||||
# determine if this is an arg or kwarg, and compute relative index
|
||||
if global_idx < num_args:
|
||||
# it's an arg
|
||||
relative_idx = global_idx
|
||||
target_spec = args_spec
|
||||
is_kwarg = False
|
||||
else:
|
||||
# it's a kwarg
|
||||
relative_idx = global_idx - num_args
|
||||
target_spec = kwargs_spec
|
||||
is_kwarg = True
|
||||
|
||||
# save the node name before removing
|
||||
removed_node_name = input_node.name
|
||||
|
||||
# remove the node from the graph
|
||||
graph.erase_node(input_node)
|
||||
|
||||
# update pytree info: remove the corresponding spec and arg name
|
||||
target_spec.children_specs.pop(relative_idx)
|
||||
if is_kwarg:
|
||||
target_spec.context.pop(relative_idx)
|
||||
orig_args.pop(global_idx)
|
||||
|
||||
# update pytree info recursively with __post_init__ starting at leaves
|
||||
_call_post_init_recursive(in_spec)
|
||||
|
||||
return removed_node_name
|
||||
|
||||
@ -56,7 +56,7 @@ class TorchAttentionReference:
|
||||
v_flat = v.reshape(1, batch_size * seq_len, -1)
|
||||
|
||||
# Call torch backend via custom op registry
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
|
||||
q_flat,
|
||||
k_flat,
|
||||
v_flat,
|
||||
@ -97,7 +97,7 @@ class TorchAttentionReference:
|
||||
|
||||
This function directly calls the torch backend implementation via custom op registry.
|
||||
"""
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -148,7 +148,7 @@ class TorchAttentionReference:
|
||||
batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
|
||||
|
||||
# Call torch backend via custom op registry
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
|
||||
q_flat,
|
||||
k_flat,
|
||||
v_flat,
|
||||
@ -185,7 +185,7 @@ class TorchAttentionReference:
|
||||
|
||||
This demonstrates how to use the torch backend with additional features.
|
||||
"""
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
||||
@ -0,0 +1,65 @@
|
||||
import os
|
||||
|
||||
import onnx
|
||||
import pytest
|
||||
from _model_test_utils import get_small_model_config
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.export import export_onnx
|
||||
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, output_dir, num_attn_ops",
|
||||
[
|
||||
(
|
||||
"Qwen/Qwen2.5-3B-Instruct",
|
||||
"/tmp/test_ad_export_onnx_qwen2.5-3b",
|
||||
36,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_ad_export_onnx(model: str, output_dir: str, num_attn_ops: int):
|
||||
ad_config = LlmArgs(
|
||||
model=get_small_model_config(model)["args"]["model"],
|
||||
mode="export_edgellm_onnx",
|
||||
max_batch_size=13,
|
||||
max_seq_len=4,
|
||||
)
|
||||
ad_config.transforms["export_to_onnx"]["output_dir"] = output_dir
|
||||
export_onnx(ad_config)
|
||||
|
||||
# check if the output directory exists
|
||||
assert os.path.exists(output_dir)
|
||||
|
||||
# check if the output json files exist
|
||||
assert os.path.exists(os.path.join(output_dir, "added_tokens.json"))
|
||||
assert os.path.exists(os.path.join(output_dir, "config.json"))
|
||||
assert os.path.exists(os.path.join(output_dir, "processed_chat_template.json"))
|
||||
assert os.path.exists(os.path.join(output_dir, "special_tokens_map.json"))
|
||||
assert os.path.exists(os.path.join(output_dir, "tokenizer.json"))
|
||||
assert os.path.exists(os.path.join(output_dir, "tokenizer_config.json"))
|
||||
assert os.path.exists(os.path.join(output_dir, "vocab.json"))
|
||||
|
||||
# check if the output onnx file exists
|
||||
assert os.path.exists(os.path.join(output_dir, "model.onnx"))
|
||||
|
||||
# check if the onnx file has the expected number of AttentionPlugin operators
|
||||
onnx_model = onnx.load(os.path.join(output_dir, "model.onnx"))
|
||||
attn_ops = [node for node in onnx_model.graph.node if node.op_type == "AttentionPlugin"]
|
||||
assert len(attn_ops) == num_attn_ops
|
||||
|
||||
# check input and output names of the model
|
||||
actual_inputs_names = {input.name for input in onnx_model.graph.input}
|
||||
actual_outputs_names = {output.name for output in onnx_model.graph.output}
|
||||
expect_inputs_names = {
|
||||
"input_ids",
|
||||
"context_lengths",
|
||||
"rope_rotary_cos_sin",
|
||||
"kvcache_start_index",
|
||||
"last_token_ids",
|
||||
}
|
||||
expect_outputs_names = {"logits"}
|
||||
expect_inputs_names.update({f"past_key_values_{i}" for i in range(num_attn_ops)})
|
||||
expect_outputs_names.update({f"present_key_values_{i}" for i in range(num_attn_ops)})
|
||||
assert actual_inputs_names == expect_inputs_names
|
||||
assert actual_outputs_names == expect_outputs_names
|
||||
@ -0,0 +1,242 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for fuse_rope_attention transformation.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.export import Dim
|
||||
|
||||
# Import modules to register custom ops (torch.ops.auto_deploy.*)
|
||||
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention # noqa: F401
|
||||
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_rope # noqa: F401
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def _get_cos_sin(x: torch.Tensor, head_dim):
|
||||
"""Precompute cos and sin for RoPE."""
|
||||
batch_size = x.size(0)
|
||||
seq_len = x.size(1)
|
||||
cos = x.new_zeros(batch_size, seq_len, head_dim)
|
||||
sin = x.new_zeros(batch_size, seq_len, head_dim)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class RopeAttentionModel(torch.nn.Module):
|
||||
"""
|
||||
Model that implements the rope + attention pattern that fuse_rope_attention expects.
|
||||
|
||||
Pattern:
|
||||
1. q_proj, k_proj, v_proj
|
||||
2. view to [batch, seq, num_heads, head_dim]
|
||||
3. contiguous
|
||||
4. torch_rope_with_explicit_cos_sin for q and k
|
||||
5. torch_attention with layout="bsnd"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_seq_len: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = head_dim * num_q_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.num_q_heads = num_q_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = hidden_size // num_q_heads
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
# Linear projections
|
||||
self.q_proj = torch.nn.Linear(hidden_size, num_q_heads * self.head_dim, bias=False)
|
||||
self.k_proj = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
|
||||
self.v_proj = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
|
||||
self.o_proj = torch.nn.Linear(num_q_heads * self.head_dim, hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor, # Required for export API compatibility
|
||||
) -> torch.Tensor:
|
||||
x = input_ids.unsqueeze(-1).expand(-1, -1, self.hidden_size).to(torch.float16)
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
# Generate q, k, v: [batch, seq, hidden_size] -> [batch, seq, num_heads * head_dim]
|
||||
q = self.q_proj(x) # [batch, seq, num_q_heads * head_dim]
|
||||
k = self.k_proj(x) # [batch, seq, num_kv_heads * head_dim]
|
||||
v = self.v_proj(x) # [batch, seq, num_kv_heads * head_dim]
|
||||
|
||||
# View to [batch, seq, num_heads, head_dim]
|
||||
q = q.view(batch_size, seq_len, self.num_q_heads, self.head_dim)
|
||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# Apply contiguous (this is part of the pattern)
|
||||
q = torch.ops.aten.contiguous(q)
|
||||
k = torch.ops.aten.contiguous(k)
|
||||
|
||||
# Precompute cos and sin for RoPE
|
||||
cos, sin = _get_cos_sin(x, self.head_dim)
|
||||
|
||||
# Apply RoPE with unsqueeze_dim=2 for BSND layout
|
||||
# NOTE(yoco): This should be (q, k) according to the definition of torch_rope_with_explicit_cos_sin.
|
||||
# However, the actual graph is (k, q), Need further investigation.
|
||||
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(k, q, cos, sin, 2)
|
||||
|
||||
# Apply attention with layout="bsnd"
|
||||
attn_output = torch.ops.auto_deploy.torch_attention(
|
||||
k_rot,
|
||||
q_rot,
|
||||
v,
|
||||
None, # attn_mask
|
||||
0.0, # dropout_p
|
||||
True, # is_causal
|
||||
None, # scale
|
||||
None, # sinks
|
||||
None, # sliding_window
|
||||
None, # logit_cap
|
||||
"bsnd", # layout
|
||||
)
|
||||
|
||||
# Reshape back and apply output projection
|
||||
attn_output = attn_output.reshape(batch_size, seq_len, -1)
|
||||
output = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def _apply_rope_manual(self, q, k, cos, sin):
|
||||
"""Manual rope application for fallback/comparison."""
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
cos = cos.unsqueeze(2) # unsqueeze_dim=2 for BSND
|
||||
sin = sin.unsqueeze(2)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
def get_dynamic_shapes(self):
|
||||
return [
|
||||
{0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)},
|
||||
{0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)},
|
||||
]
|
||||
|
||||
|
||||
def convert_placeholder_meta_to_cuda(gm: torch.fx.GraphModule):
|
||||
"""Convert placeholder meta to cuda."""
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
node.meta["val"] = node.meta["val"].to("cuda")
|
||||
|
||||
|
||||
def _run_test(
|
||||
head_dim: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
):
|
||||
"""Helper function to run the transformation test."""
|
||||
model = RopeAttentionModel(
|
||||
head_dim=head_dim,
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
).to("cuda", torch.float16)
|
||||
input_ids = torch.randint(0, 10, (batch_size, seq_len), device="cuda", dtype=torch.int32)
|
||||
position_ids = torch.randint(0, 10, (batch_size, seq_len), device="cuda", dtype=torch.int32)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
# Export to graph module
|
||||
args = (input_ids, position_ids)
|
||||
gm = torch_export_to_gm(model, args=args, dynamic_shapes=dynamic_shapes, clone=True)
|
||||
|
||||
# Minimal model config stub exposing num_attention_heads, hidden_size, and num_key_value_heads
|
||||
class Factory:
|
||||
def _get_model_config(self):
|
||||
ModelConfig = namedtuple(
|
||||
"ModelConfig", ["num_attention_heads", "hidden_size", "num_key_value_heads"]
|
||||
)
|
||||
return ModelConfig(
|
||||
num_attention_heads=num_q_heads,
|
||||
hidden_size=head_dim * num_q_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
), None
|
||||
|
||||
# Apply fuse_rope_attention transformation
|
||||
|
||||
optimizer = InferenceOptimizer(
|
||||
Factory(),
|
||||
{
|
||||
"fuse_rope_attention": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
},
|
||||
)
|
||||
convert_placeholder_meta_to_cuda(gm)
|
||||
gm_transformed = optimizer(None, gm)
|
||||
gm_transformed.to("cuda")
|
||||
|
||||
fused_nodes = gm_transformed.graph.find_nodes(
|
||||
op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
|
||||
)
|
||||
assert len(fused_nodes) == 1, f"Expected 1 AttentionPlugin node, got {len(fused_nodes)}"
|
||||
input_nodes = gm_transformed.graph.find_nodes(op="placeholder")
|
||||
assert len(input_nodes) == 5, f"Expected 5 input nodes, got {len(input_nodes)}"
|
||||
input_nodes_targets = {node.target for node in input_nodes}
|
||||
assert input_nodes_targets == {
|
||||
"input_ids",
|
||||
"context_lengths",
|
||||
"rope_rotary_cos_sin",
|
||||
"kvcache_start_index",
|
||||
"past_key_values_0",
|
||||
}
|
||||
out_node = gm_transformed.graph.find_nodes(op="output")[0]
|
||||
assert len(out_node.args[0]) == 2, f"Expected 2 output nodes, got {len(out_node.args[0])}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"head_dim,num_q_heads,num_kv_heads,batch_size,seq_len",
|
||||
[
|
||||
# GQA (Grouped-Query Attention): num_q_heads > num_kv_heads
|
||||
pytest.param(64, 14, 2, 2, 16, id="gqa_ratio_7"),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_fuse_rope_attention(
|
||||
head_dim: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
):
|
||||
"""Test fuse_rope_attention transformation with various configurations."""
|
||||
_run_test(
|
||||
head_dim=head_dim,
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
)
|
||||