From 4af47208d815540766511eb0073eecd1ac9845e2 Mon Sep 17 00:00:00 2001 From: nvyocox Date: Sat, 31 Jan 2026 04:43:11 +0800 Subject: [PATCH] [None][feat] Export ONNX for DriveOS LLM (#10117) Signed-off-by: yocox --- .gitignore | 1 + docker/common/install_base.sh | 1 + .../torch/auto_deploy/advanced/export_onnx.md | 93 +++ docs/source/torch/auto_deploy/auto-deploy.md | 1 + examples/auto_deploy/onnx_export_llm.py | 78 ++ jenkins/current_image_tags.properties | 8 +- requirements.txt | 2 + .../config/export_edgellm_onnx.yaml | 126 ++++ .../auto_deploy/custom_ops/onnx_attention.py | 194 +++++ .../_torch/auto_deploy/export/export.py | 59 +- tensorrt_llm/_torch/auto_deploy/llm_args.py | 3 +- .../transform/graph_module_visualizer.py | 691 ++++++++++++++++++ .../_torch/auto_deploy/transform/interface.py | 39 +- .../transform/library/_chat_template.py | 382 ++++++++++ .../transform/library/_config_export.py | 205 ++++++ .../transform/library/_onnx_schemas.py | 258 +++++++ .../transform/library/adapt_to_edgellm.py | 166 +++++ .../transform/library/export_to_onnx.py | 418 +++++++++++ .../transform/library/fuse_rope_attention.py | 551 ++++++++++++++ .../library/gather_last_token_ids.py | 162 ++++ .../library/short_reshape_attention_output.py | 165 +++++ .../_torch/auto_deploy/transform/optimizer.py | 4 +- .../_torch/auto_deploy/utils/_graph.py | 214 +++++- .../_utils_test/torch_attention_reference.py | 8 +- .../unit/singlegpu/test_ad_export_onnx.py | 65 ++ .../library/test_fuse_rope_attention.py | 242 ++++++ 26 files changed, 4115 insertions(+), 21 deletions(-) create mode 100644 docs/source/torch/auto_deploy/advanced/export_onnx.md create mode 100644 examples/auto_deploy/onnx_export_llm.py create mode 100644 tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/onnx_attention.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/graph_module_visualizer.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/_chat_template.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/_config_export.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/export_to_onnx.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/gather_last_token_ids.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_export_onnx.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rope_attention.py diff --git a/.gitignore b/.gitignore index d409a49f48..d7c360cce6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ .vscode +.cursor *.engine *.engine.config *.cache diff --git a/docker/common/install_base.sh b/docker/common/install_base.sh index 2d20657895..20d83a3d76 100644 --- a/docker/common/install_base.sh +++ b/docker/common/install_base.sh @@ -68,6 +68,7 @@ init_ubuntu() { gdb \ git-lfs \ clang \ + graphviz \ lld \ llvm \ libclang-rt-dev \ diff --git a/docs/source/torch/auto_deploy/advanced/export_onnx.md b/docs/source/torch/auto_deploy/advanced/export_onnx.md new file mode 100644 index 0000000000..19c8aa5682 --- /dev/null +++ 
b/docs/source/torch/auto_deploy/advanced/export_onnx.md @@ -0,0 +1,93 @@ +# Export ONNX for EdgeLLM + +AutoDeploy provides a mode to export PyTorch/HuggingFace models to ONNX format specifically designed for EdgeLLM deployment. This mode performs graph transformations to fuse RoPE (Rotary Position Embedding) and attention operations into a single `AttentionPlugin` operation, then exports the optimized graph to ONNX. + +## Overview + +The `export_edgellm_onnx` mode differs from the standard AutoDeploy workflow in two key ways: + +1. **Operation Fusion**: Fuses `torch_rope_with_explicit_cos_sin` and `torch_cached_attention_with_cache` into a single `AttentionPlugin` operation +1. **ONNX Export**: Outputs an ONNX model file instead of a TensorRT Engine + +## Quick Start + +Use the `onnx_export_llm.py` script to export a model: + +```bash +cd examples/auto_deploy +python onnx_export_llm.py --model "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" +``` + +This will export the model to ONNX format in the current directory. + +## Command Line Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `--model` | str | Required | HuggingFace model name or path to a local checkpoint | +| `--device` | str | `cpu` | Device to use for export (`cpu` or `cuda`) | +| `--output_dir` | str | `.` | Directory to save the exported ONNX model | + +## Examples + +### Basic Export + +Export a DeepSeek model with default settings: + +```bash +python onnx_export_llm.py --model "Qwen/Qwen2.5-0.5B-Instruct" +``` + +### Custom Output Location + +Export to a specific directory with a custom filename: + +```bash +python onnx_export_llm.py \ + --model "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" \ + --output_dir "./exported_models" +``` + +## Output Files + +The export process generates the following files in the output directory: + +| File | Description | +|------|-------------| +| `model.onnx` | The exported ONNX model with fused attention operations | +| `config.json` | Model configuration (architecture, hidden size, etc.) | +| `tokenizer.json` | Tokenizer vocabulary and configuration | +| `tokenizer_config.json` | Tokenizer settings | +| `special_tokens_map.json` | Special token mappings | +| `processed_chat_template.json` | Processed chat template for inference | + +## Programmatic Usage + +You can also use the ONNX export functionality programmatically: + +```python +from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig + +# Create AutoDeploy config with export_edgellm_onnx mode +ad_config = AutoDeployConfig( + model="Qwen/Qwen2.5-0.5B-Instruct", + mode="export_edgellm_onnx", + max_batch_size=8, + max_seq_len=512, + device="cpu", +) + +# Configure attention backend +ad_config.attn_backend = "torch" + +# Optionally customize output location +ad_config.transforms["export_to_onnx"]["output_dir"] = "./my_output" + +# Run the export +LLM(**ad_config.to_llm_kwargs()) +``` + +## Notes + +- **Device Selection**: Using `cpu` for the `--device` option is recommended to reduce GPU memory footprint during export. +- **Custom Operations**: The exported ONNX model contains custom operations (e.g., `AttentionPlugin`) in the `trt` domain that require corresponding implementations in the target inference runtime. 
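+## Inspecting the Exported Model
+
+As a quick sanity check, you can inspect the exported graph with the `onnx` Python package (already listed in `requirements.txt`). The sketch below assumes the output path from the "Custom Output Location" example and simply lists the custom operations that the target runtime must provide:
+
+```python
+import onnx
+
+# Assumed path; adjust to match your --output_dir.
+model = onnx.load("./exported_models/model.onnx")
+
+# Nodes in the "trt" domain (e.g., AttentionPlugin) are custom operations
+# that require implementations in the target inference runtime.
+for node in model.graph.node:
+    if node.domain == "trt":
+        print(node.op_type, node.name)
+```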
diff --git a/docs/source/torch/auto_deploy/auto-deploy.md b/docs/source/torch/auto_deploy/auto-deploy.md index 185e1f321a..13ac1338d6 100644 --- a/docs/source/torch/auto_deploy/auto-deploy.md +++ b/docs/source/torch/auto_deploy/auto-deploy.md @@ -60,6 +60,7 @@ The exported graph then undergoes a series of automated transformations, includi - [Expert Configurations](./advanced/expert_configurations.md) - [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md) - [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md) +- [Export ONNX for EdgeLLM](./advanced/export_onnx.md) ## Roadmap diff --git a/examples/auto_deploy/onnx_export_llm.py b/examples/auto_deploy/onnx_export_llm.py new file mode 100644 index 0000000000..af440dc4af --- /dev/null +++ b/examples/auto_deploy/onnx_export_llm.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ONNX export script for AutoDeploy models. + +This script exports a HuggingFace model to ONNX format using the AutoDeploy +transform pipeline directly, without initializing the full LLM executor. +""" + +import argparse + +from tensorrt_llm._torch.auto_deploy.export import export_onnx +from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs + + +def main(): + parser = argparse.ArgumentParser( + description="Export HuggingFace model to ONNX format using AutoDeploy." + ) + parser.add_argument( + "--model", + type=str, + required=True, + help="The HF model to use for onnx export.", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="The device to use when exporting the model.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The directory to save the exported ONNX model.", + ) + args = parser.parse_args() + + print(f"Constructing model from {args.model}") + + # to enable dynamic batch_size, the batch size must > 1 + # NOTE(yoco): Originally this is 2, however, don't know why, when set to 2, + # the batch_size will collapse static int 2 even we explicitly it is dynamic axis. + # And more weird, when set to 13, the batch_size will be dynamic. + # Probably some value between 2 and 13 will work, + # We use 13 here for debugging purpose. 
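+    # In short: a sample batch size large enough not to be specialized keeps the batch
+    # dimension dynamic in the exported graph; 13 is an empirical, otherwise arbitrary choice.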
+ max_batch_size = 13 + max_seq_len = 4 + + # Prepare the AutoDeploy config, mode is export_edgellm_onnx + ad_config = LlmArgs( + model=args.model, + mode="export_edgellm_onnx", + max_batch_size=max_batch_size, + max_seq_len=max_seq_len, + device=args.device, + ) + ad_config.attn_backend = "torch" + if args.output_dir is not None: + ad_config.transforms["export_to_onnx"]["output_dir"] = args.output_dir + + # Use direct InferenceOptimizer instead of LLM to avoid executor initialization + export_onnx(ad_config) + + +if __name__ == "__main__": + main() diff --git a/jenkins/current_image_tags.properties b/jenkins/current_image_tags.properties index 4f4ffbd331..140f772cf6 100644 --- a/jenkins/current_image_tags.properties +++ b/jenkins/current_image_tags.properties @@ -13,7 +13,7 @@ # images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead. IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm -LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601230553-10896 -LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601230553-10896 -LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601230553-10896 -LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601230553-10896 +LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601281024-10117 +LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601281024-10117 +LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601281024-10117 +LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601281024-10117 diff --git a/requirements.txt b/requirements.txt index 5f5df1c63b..2ecc98bfdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,8 @@ mpi4py numpy<2 onnx>=1.18.0,<1.20.0 onnx_graphsurgeon>=0.5.2 +onnxscript==0.5.4 +graphviz openai polygraphy psutil diff --git a/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml b/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml new file mode 100644 index 0000000000..4ded5157bb --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml @@ -0,0 +1,126 @@ +# This is the set of transforms running in "graph" mode. In this mode, we capture the full graph +# of the model and optimize it for inference. 
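+# This file is the "export_edgellm_onnx" variant of that pipeline: after graph capture and
+# optimization, the "export_onnx"-stage transforms at the end fuse RoPE with attention and
+# export the resulting graph to ONNX.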
+transforms: + ############################################################################################ + # BUILD MODEL, EXPORT TO GRAPH MODULE, AND CLEAN UP + ############################################################################################ + build_model: + stage: factory + run_per_gm: false + device: meta + requires_clean_graph: false + export_to_gm: + stage: export + clone_state_dict: false + strict: false + run_per_gm: false + requires_clean_graph: false + cleanup_noop_slice: + stage: post_export + cleanup_noop_add: + stage: post_export + cleanup_input_constraints: + stage: post_export + ############################################################################################ + # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION + ############################################################################################ + match_moe_pattern: + stage: pattern_matcher + match_dense_moe_pattern: + stage: pattern_matcher + match_bmm_moe_pattern: + stage: pattern_matcher + match_repeat_kv: + stage: pattern_matcher + run_shape_prop: true + match_eager_attention: + stage: pattern_matcher + requires_shape_prop: true + match_sdpa_to_torch_attention: + stage: pattern_matcher + match_grouped_attention: + stage: pattern_matcher + match_attention_layout: + stage: pattern_matcher + attn_layout: bsnd + match_rope_pattern: + stage: pattern_matcher + match_rope_layout: + stage: pattern_matcher + expected_layout: bsnd + ############################################################################################ + # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION + ############################################################################################ + eliminate_redundant_transposes: + stage: pattern_matcher + quantize_int4_linear_from_config: + stage: pattern_matcher + quantize_fp8_linear_from_config: + stage: pattern_matcher + quantize_nvfp4_linear_from_config: + stage: pattern_matcher + quantize_fp8_bmm_from_config: + stage: pattern_matcher + quantize_fp8_from_graph: + stage: pattern_matcher + quantize_nvfp4_from_graph: + stage: pattern_matcher + quantize_fp8_moe: + stage: pattern_matcher + quantize_nvfp4_moe: + stage: pattern_matcher + quantize_mxfp4_moe: + stage: pattern_matcher + ############################################################################################ + # MOVE MODEL AND LOAD WEIGHTS + ############################################################################################ + load_weights: + stage: weight_load + run_per_gm: false + checkpoint_device: cpu + move_inputs_to_device: + stage: weight_load + checkpoint_device: cpu + run_per_gm: false + ############################################################################################ + # RUN POST-LOAD FUSION AND OPTIMIZATIONS + ############################################################################################ + fuse_gemms: + stage: post_load_fusion + enabled: + false # https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this causes OOMs on GPU + # But while export ONNX for EdgeLLM, we compile the model on CPU.. + # So we can enable this transform. + # TODO(yoco):However, it currently make non-strict ONNX export fail. 
+ fuse_moe: + stage: post_load_fusion + enabled: true + backend: trtllm + fuse_fp8_moe: + stage: post_load_fusion + enabled: true + backend: trtllm + fuse_nvfp4_moe: + stage: post_load_fusion + enabled: false + ############################################################################################ + # VISUALIZE GRAPH + ############################################################################################ + visualize_namespace: + stage: visualize + enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460 + ############################################################################################ + # FUSE Rope Attention & export to ONNX + ############################################################################################ + fuse_rope_attention: + stage: export_onnx + short_reshape_attention_output: + stage: export_onnx + gather_last_token_ids: + stage: export_onnx + adapt_to_edgellm: + stage: export_onnx + export_to_onnx: + stage: export_onnx + output_dir: "." + output_name: "model.onnx" diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/onnx_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/onnx_attention.py new file mode 100644 index 0000000000..6e3de6149e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/onnx_attention.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom operations for ONNX export of attention mechanisms. + +This module provides placeholder custom operations for exporting attention-related +operations to ONNX format. These operations serve as intermediate representations +during the graph transformation pipeline and are intended to be replaced by actual +backend implementations during deployment. +""" + +from typing import Tuple + +import torch + + +@torch.library.custom_op("auto_deploy::torch_onnx_attention_plugin", mutates_args=()) +def attention_plugin( + # Inputs + qkv: torch.Tensor, + past_key_values: torch.Tensor, + context_lengths: torch.Tensor, + rope_rotary_cos_sin: torch.Tensor, + kvcache_start_index: torch.Tensor, + # Attributes + enable_tree_attention: int, + head_size: int, + num_kv_heads: int, + num_q_heads: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fused attention operation with integrated RoPE (Rotary Position Embedding). + + This custom operation combines rotary position embedding, and scaled + dot-product attention into a single fused operation. It also handles + KV-cache management for efficient autoregressive generation. + + Note: + This is a placeholder implementation for ONNX export. The actual computation + is performed by the backend runtime (e.g., TensorRT, EdgeLLM + Args: + qkv: Concatenated query, key, value tensor of shape + [batch_size, seq_len, (num_q_heads + 2 * num_kv_heads) * head_size]. + past_key_values: KV-cache tensor of shape + [batch_size, 2, num_kv_heads, past_seq_len, head_size]. + context_lengths: Sequence lengths for each batch element, shape [batch_size]. 
+ rope_rotary_cos_sin: Precomputed RoPE cosine and sine values of shape + [batch_size, max_seq_len, head_size]. + kvcache_start_index: Starting index in KV-cache for each batch element, + shape [batch_size]. + enable_tree_attention: Flag to enable tree attention mode (0 or 1). + head_size: Dimension of each attention head. + num_kv_heads: Number of key-value heads (for grouped-query attention). + num_q_heads: Number of query heads. + + Returns: + A tuple containing: + - attention_output: Attention output tensor of shape + [batch_size, seq_len, num_q_heads, head_size]. + - present_key_values: Updated KV-cache tensor of shape + [batch_size, 2, num_kv_heads, present_seq_len, head_size]. + """ + return qkv.new_empty(0), past_key_values.new_empty(0) + + +@attention_plugin.register_fake +def attention_plugin_fake( + qkv: torch.Tensor, + past_key_values: torch.Tensor, + context_lengths: torch.Tensor, + rope_rotary_cos_sin: torch.Tensor, + kvcache_start_index: torch.Tensor, + enable_tree_attention: int, + head_size: int, + num_kv_heads: int, + num_q_heads: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fake implementation of attention_plugin for torch.compile shape inference. + + This function computes the output shapes without performing actual computation, + enabling torch.compile to trace through the custom operation. + + Args: + qkv: Concatenated QKV tensor. + past_key_values: Previous KV-cache tensor. + context_lengths: Sequence lengths per batch. + rope_rotary_cos_sin: RoPE embedding values. + kvcache_start_index: KV-cache start indices. + enable_tree_attention: Tree attention flag. + head_size: Attention head dimension. + num_kv_heads: Number of KV heads. + num_q_heads: Number of query heads. + + Returns: + Tuple of empty tensors with correct shapes for attention output and + present KV-cache. + """ + batch_size = qkv.size(0) + seq_len = qkv.size(1) + past_len = past_key_values.size(3) + present_kv_len = seq_len + past_len + attn_shape = (batch_size, seq_len, num_q_heads, head_size) + present_kv_shape = (batch_size, 2, num_kv_heads, present_kv_len, head_size) + return torch.empty(attn_shape, device=qkv.device, dtype=qkv.dtype), torch.empty( + present_kv_shape, device=past_key_values.device, dtype=past_key_values.dtype + ) + + +def _fake_gather_nd(data: torch.Tensor, indices: torch.Tensor, batch_dims: int) -> torch.Tensor: + """Compute output shape for GatherND operation without actual gathering. + + This helper function creates an empty tensor with the correct output shape + for the GatherND operation, used for both the actual op and its fake + implementation. + + Args: + data: Source tensor of shape [batch_size, seq_len, embedding_dim]. + indices: Index tensor of shape [batch_size, num_selected] or + [batch_size, num_selected, index_depth]. + batch_dims: Number of leading batch dimensions (must be 1). + + Returns: + Empty tensor with shape [batch_size, num_selected, embedding_dim]. + + Raises: + AssertionError: If batch_dims != 1, data is not 3D, or indices is not 2D/3D. 
+ """ + assert batch_dims == 1, "Current only support batch_dims = 1" + assert data.ndim == 3, "Current only support 3D data tensor" + assert indices.ndim == 2 or indices.ndim == 3, "Current only support 2D or 3D indices tensor" + + dim_batch_size = indices.size(0) + dim_selected_token = indices.size(1) + dim_emb = data.size(-1) + result_shape = [dim_batch_size, dim_selected_token, dim_emb] + return torch.empty(result_shape, device=data.device, dtype=data.dtype) + + +@torch.library.custom_op("auto_deploy::torch_onnx_gather_nd", mutates_args=()) +def gather_nd( + data: torch.Tensor, + indices: torch.Tensor, + batch_dims: int, +) -> torch.Tensor: + """N-dimensional gather operation following ONNX gather_nd semantics. + + Gathers slices from the data tensor based on indices, supporting batched + operations. This operation is commonly used for selecting specific tokens + from a sequence based on their positions. + + Note: + This is a placeholder implementation for ONNX export. The actual + computation is performed by the backend runtime. + + Args: + data: Source tensor to gather from, shape [batch_size, seq_len, embedding_dim]. + indices: Index tensor specifying which elements to gather, + shape [batch_size, num_selected] or [batch_size, num_selected, index_depth]. + batch_dims: Number of leading dimensions to treat as batch dimensions. + Currently only batch_dims=1 is supported. + + Returns: + Gathered tensor of shape [batch_size, num_selected, embedding_dim]. + """ + return _fake_gather_nd(data, indices, batch_dims) + + +@gather_nd.register_fake +def gather_nd_fake( + data: torch.Tensor, + indices: torch.Tensor, + batch_dims: int, +) -> torch.Tensor: + """Fake implementation of gather_nd for torch.compile shape inference. + + Args: + data: Source tensor to gather from. + indices: Index tensor for gathering. + batch_dims: Number of batch dimensions. + + Returns: + Empty tensor with the correct output shape. + """ + return _fake_gather_nd(data, indices, batch_dims) diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py index fb1401e6eb..f21275b3ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/export.py +++ b/tensorrt_llm/_torch/auto_deploy/export/export.py @@ -3,7 +3,7 @@ from collections import defaultdict from contextlib import nullcontext from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.export as te @@ -15,6 +15,9 @@ from ..utils.logger import ad_logger from ..utils.node_utils import is_op from .interface import apply_export_patches +if TYPE_CHECKING: + from ..llm_args import LlmArgs + try: from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context except ImportError: @@ -342,3 +345,57 @@ def torch_export_to_gm( ad_logger.debug("exported graph: " + str(egm)) return egm + + +def export_onnx(ad_config: "LlmArgs") -> nn.Module: + """Export model to ONNX using InferenceOptimizer directly. + + This is a lightweight export path that avoids initializing the full LLM executor, + which requires KVCacheManager and other runtime components not needed for ONNX export. + + Args: + ad_config: The AutoDeploy configuration for the model. Should use a mode like + "export_edgellm_onnx" that includes the export_to_onnx transform. + + Returns: + The transformed model after running through the inference optimizer pipeline. 
+ + Example: + >>> from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs + >>> from tensorrt_llm._torch.auto_deploy.export import export_onnx + >>> + >>> ad_config = LlmArgs( + ... model="meta-llama/Llama-2-7b-hf", + ... mode="export_edgellm_onnx", + ... max_batch_size=13, + ... max_seq_len=4, + ... device="cpu", + ... ) + >>> ad_config.transforms["export_to_onnx"]["output_dir"] = "/tmp/onnx_output" + >>> model = export_onnx(ad_config) + """ + # Import here to avoid circular imports + from ..shim.interface import CachedSequenceInterface + from ..transform.optimizer import InferenceOptimizer + + # 1. Create factory from config + factory = ad_config.create_factory() + + # 2. Create CachedSequenceInterface (lightweight, no KVCacheManager initialization) + cache_seq_interface = CachedSequenceInterface( + max_seq_len=ad_config.max_seq_len, + max_batch_size=ad_config.max_batch_size, + device=ad_config.device, + kv_cache_config=ad_config.kv_cache_config, + max_num_tokens=ad_config.max_num_tokens, + vocab_size_padded=factory.vocab_size_padded, + ) + + # 3. Create InferenceOptimizer with transform config + inference_optimizer = InferenceOptimizer( + factory=factory, + config=ad_config.transforms, + ) + + # 4. Run the transform pipeline (includes export_to_onnx transform) + return inference_optimizer(cache_seq_interface) diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index eb22e116b7..c44e9a9392 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -206,7 +206,7 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): ) ### INFERENCE OPTIMIZER CONFIG ################################################################# - mode: Literal["graph", "transformers"] = Field( + mode: Literal["graph", "transformers", "export_edgellm_onnx"] = Field( default="graph", description="The mode to use for the inference optimizer. Currently, we " "support only the 'graph' and 'transformers' modes, i.e., full-graph capture + optimization" @@ -335,5 +335,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): mapping = { "graph": str(config_path / "default.yaml"), "transformers": str(config_path / "transformers.yaml"), + "export_edgellm_onnx": str(config_path / "export_edgellm_onnx.yaml"), } return mapping.get(mode) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/graph_module_visualizer.py b/tensorrt_llm/_torch/auto_deploy/transform/graph_module_visualizer.py new file mode 100644 index 0000000000..530f70dcd7 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/graph_module_visualizer.py @@ -0,0 +1,691 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PyTorch GraphModule Visualization Tool + +This module provides functionality to convert PyTorch GraphModule to Graphviz diagrams. +Supports different node styles and detailed graph annotations. 
+ +Key Features: +- Convert FX GraphModule to Graphviz diagrams +- Display tensor shape information on edges +- Adjust edge width based on tensor element count +- Intelligent port assignment for multi-input/output handling +- Color coding based on tensor identity + +Usage Example: + import torch + import torch.fx as fx + from graph_module_visualizer import to_dot + + # Trace model + model = YourModel() + traced = fx.symbolic_trace(model) + + # Generate visualization + dot = to_dot(traced, format="svg", include_shapes=True) + +Requirements: pip install graphviz +""" + +import math +import re +from typing import Any, Dict, Optional + +import graphviz +import torch +import torch.fx as fx +from torch.fx import GraphModule + +from ..utils.logger import ad_logger + + +def _calculate_edge_width(val) -> float: + """ + Calculate edge width based on tensor element count + Formula: log10(num_elements) + 2 + + Args: + val: FakeTensor, Tensor, or any object with .shape attribute + """ + min_width = 2.0 + max_width = 10.0 + if not hasattr(val, "shape"): + return min_width # Default width + + try: + # Calculate total number of elements + num_elements = 1 + for dim in val.shape: + if isinstance(dim, int): + num_elements *= dim + else: + num_elements *= 1 + + if num_elements <= 0: + return min_width + + # Calculate width: log10(element count) + 2 + width = math.log10(num_elements) + 2 + + # Constrain width range (2.0 to 10.0) + width = max(min_width, min(max_width, width)) + + return width + + except (ValueError, TypeError): + return min_width # Use default width on error + + +def _get_edge_color(source_node_name: str, output_index: int) -> str: + """ + Assign color based on source node and actual output index + Ensures all edges of the same tensor use the same color + """ + colors = [ + "#FF6B9D80", # Pink + "#4ECDC480", # Mint green + "#45B7D180", # Sky blue + "#96CEB480", # Light green + "#DDA0DD80", # Light purple + "#98D8C880", # Teal + "#BB8FCE80", # Violet + "#85C1E980", # Light blue + "#F8C47180", # Peach + "#82E0AA80", # Mint + "#F7DC6F80", # Lemon yellow + "#AED6F180", # Light sky blue + ] + + # Use combination of source_node + output_index + # Ensures multiple edges of the same tensor have the same color + tensor_id = f"{source_node_name}_{output_index}" + color_index = hash(tensor_id) % len(colors) + return colors[color_index] + + +def _get_port_of_five(input_idx, total_inputs): + """ + 0 xxxxxxxx a = 7 + 1 xxxxxxxx b = 2 + 2 xxxxxxx x = 2 * 8 = 16 + 3 xxxxxxx + 4 xxxxxxx + """ + K = 5 + a, b = total_inputs // K, total_inputs % K + x = b * (a + 1) + if input_idx < x: + return input_idx // (a + 1) + else: + return (input_idx - x) // a + b + + +def _get_input_port(input_index: int, total_inputs: int) -> str: + """ + Get input port based on input index and total input count + """ + if total_inputs <= 1: + return "n" # Single input uses default port + elif total_inputs == 2: + return "nw" if input_index == 0 else "ne" # Northwest, northeast + elif total_inputs == 3: + ports = ["nw", "n", "ne"] # Northwest, north, northeast + return ports[input_index] + elif total_inputs == 4: + ports = ["nw", "n", "ne", "e"] + return ports[input_index] + else: + # 5+ inputs: west, northwest, north, northeast, east in order + ports = ["w", "nw", "n", "ne", "e"] + # Cycle through ports for more than 5 inputs + return ports[_get_port_of_five(input_index, total_inputs)] + + +def _get_output_port(output_index: int, total_outputs: int) -> str: + """ + Get output port based on output index and total output count 
(symmetric to input, but on bottom) + """ + if total_outputs <= 1: + return "s" # Single output uses default port + elif total_outputs == 2: + return "sw" if output_index == 0 else "se" # Southwest, southeast + elif total_outputs == 3: + ports = ["sw", "s", "se"] # Southwest, south, southeast + return ports[output_index] + else: + # 4+ outputs: west, southwest, south, southeast, east in order + ports = ["w", "sw", "s", "se", "e"] + # Cycle through ports for more than 5 outputs + return ports[_get_port_of_five(output_index, total_outputs)] + + +def to_dot( + graph_module: GraphModule, + name: str, + save_path: str, + format: str = "svg", + include_shapes: bool = True, +) -> Optional["graphviz.Digraph"]: + """ + Convert PyTorch GraphModule to Graphviz diagram + + Args: + graph_module: GraphModule to visualize + name: Name of the diagram + save_path: Save path, if None uses name + format: Output format ('png', 'pdf', 'svg', 'dot', etc.) + include_shapes: Whether to include tensor shape information + + Returns: + graphviz.Digraph object + """ + # Create Graphviz diagram + dot = graphviz.Digraph( + name=name, comment=f"PyTorch GraphModule: {graph_module.__class__.__name__}", format=format + ) + + # Set graph attributes + dot.attr(rankdir="TB") # Top to bottom + dot.attr("node", shape="box", style="rounded,filled", height="0.2") + # Remove default edge color, let each edge use its own color + + # Node style configuration + node_styles = { + "placeholder": {"fillcolor": "lightgreen", "shape": "box"}, + "get_attr": {"fillcolor": "lightcyan", "shape": "box"}, + "call_function": {"fillcolor": "lightblue", "shape": "box"}, + "call_method": {"fillcolor": "lightyellow", "shape": "box"}, + "call_module": {"fillcolor": "lightpink", "shape": "box"}, + "output": {"fillcolor": "lightcoral", "shape": "box"}, + } + + # Analyze graph structure + graph = graph_module.graph + nodes = list(graph.nodes) + node_labels = {} + + # Process each node + for node in nodes: + # Get basic node information + node_name = node.name + op_type = node.op + + # Create node label + label = _get_node_label(graph_module, node) + node_labels[node_name] = label + + # Set node style + style = node_styles.get(op_type, {"fillcolor": "white", "shape": "box"}) + + # Add node to diagram + node_attrs = { + "label": label, + "fillcolor": style["fillcolor"], + "shape": style["shape"], + "tooltip": node_name, # Use node name as tooltip + } + # If the node has no value, set the fillcolor to red + if "val" not in node.meta: + node_attrs["fillcolor"] = "red" + elif isinstance(node.meta["val"], torch.Tensor): + node_attrs["label"] += "\n" + str(node.meta["val"].device) + elif isinstance(node.meta["val"], (list, tuple)): + for val in node.meta["val"]: + if isinstance(val, torch.Tensor): + node_attrs["label"] += "\n" + str(val.device) + else: + node_attrs["label"] += "\n" + str(val) + + dot.node(node_name, **node_attrs) + + # First collect all edge information + edges = [] # Format: (source_node, target_node, val, source_output_index, target_input_index) + node_inputs = {} # Input list for each node: [(source_node_name, output_index)] + node_outputs = {} # Output list for each node: [(target_node_name, input_index)] + + # Initialize + for node in nodes: + node_inputs[node.name] = [] + node_outputs[node.name] = [] + + # Collect edge information + def _add_edge_from_node_with_index(input_idx, source_node: fx.Node, target_node: fx.Node): + """Add edge from source_node to target_node, automatically determining correct output index""" + # Extract val 
(FakeTensor or Tensor) from node.meta["val"] + # This is more reliable than tensor_meta since FakeTensorProp always populates val + val = None + if include_shapes and hasattr(source_node, "meta") and "val" in source_node.meta: + val = source_node.meta["val"] + + # Calculate indices + source_output_index = _determine_output_index(source_node, target_node) + + # Add edge and update indices (store tuple containing node name and corresponding index) + edges.append((source_node.name, target_node.name, val, source_output_index, input_idx)) + node_inputs[target_node.name].append((source_node.name, source_output_index)) + node_outputs[source_node.name].append((target_node.name, input_idx)) + + def _determine_output_index(source_node: fx.Node, target_node: fx.Node) -> int: + """Determine the output index for the edge from source_node to target_node""" + # Check if target_node is a getitem operation + if ( + target_node.op == "call_function" + and hasattr(target_node.target, "__name__") + and target_node.target.__name__ == "getitem" + and len(target_node.args) >= 2 + ): + # target_node.args[0] is source node, target_node.args[1] is index + if target_node.args[0] == source_node and isinstance(target_node.args[1], int): + return target_node.args[1] + + # By default, FX nodes have only one output + return 0 + + def _fix_problematic_negative_one(text): + """Fix the problematic 9223372036854775807 number back to -1""" + return text.replace("9223372036854775807", "-1") + + def _format_constant_label(value): + """Format constant value for display as node label""" + # Special case: direct problematic integer + if isinstance(value, int) and value == 9223372036854775807: # 2^63 - 1 + return "-1" + + # Handle different value types + if isinstance(value, (int, float, bool, str)): + result = str(value) + elif hasattr(value, "shape"): + # Tensor-like object with shape + if hasattr(value, "numel") and value.numel() > 6: + return f"shape={tuple(value.shape)}" + else: + result = str(value) + elif isinstance(value, (list, tuple)): + # Collection of elements + if len(value) > 6: + return f"length={len(value)}" + else: + result = str(value) + else: + # Other types + result = str(value)[:50] + ("..." if len(str(value)) > 50 else "") + + # Apply the fix to any string representation + return _fix_problematic_negative_one(result) + + # Store constants to be created later + constants_to_create = [] + + def _process_arg(input_idx, arg, target_node: fx.Node): + """Process a single argument and create appropriate edges in the graph. + + Handles three types of arguments: + - fx.Node: Creates an edge from the source node to the target node. + - Container (list/tuple): Iterates through and creates edges for any fx.Node elements. + - Constant: Creates a constant node and adds it to the graph with appropriate edges. + + Args: + input_idx: The index of this argument in the target node's input list. + arg: The argument to process. Can be an fx.Node, a container (list/tuple), + or a constant value. + target_node: The fx.Node that receives this argument as input. + + Note: + This function modifies external state including `constants_to_create`, + `node_inputs`, `node_outputs`, and `edges`. 
+ """ + if isinstance(arg, fx.Node): + _add_edge_from_node_with_index(input_idx, arg, target_node) + elif isinstance(arg, (list, tuple)): + for sub_arg in arg: + if isinstance(sub_arg, fx.Node): + _add_edge_from_node_with_index(input_idx, sub_arg, target_node) + else: + # This is a constant value - store for later creation + const_node_name = f"const_{target_node.name}_{input_idx}" + constants_to_create.append((const_node_name, arg, target_node.name, input_idx)) + + # Add to node tracking immediately + node_inputs[target_node.name].append(const_node_name) + if const_node_name not in node_outputs: + node_outputs[const_node_name] = [] + node_outputs[const_node_name].append((target_node.name, input_idx)) + + # Create edge from constant to target + edges.append((const_node_name, target_node.name, None, 0, input_idx)) + + def _determine_node_output_count(node_name: str) -> int: + """Determine the actual output count of a node (applies to all node types)""" + # Check if any getitem operations use this node + max_index = -1 + for n in nodes: + # Only check getitem operation pattern, don't restrict source node op type + if ( + n.op == "call_function" + and hasattr(n.target, "__name__") + and n.target.__name__ == "getitem" + and len(n.args) >= 2 + and hasattr(n.args[0], "name") + and n.args[0].name == node_name + and isinstance(n.args[1], int) + ): + max_index = max(max_index, n.args[1]) + + # If getitem operations found, output count is max_index + 1 + if max_index >= 0: + return max_index + 1 + + # By default, nodes have only one output + return 1 + + # Traverse all nodes to collect edges + for node in nodes: + if not node.args: + continue + + for input_idx, arg in enumerate(node.args): + _process_arg(input_idx, arg, node) + + # Print 10 nodes with most inputs + # ad_logger.debug("Nodes with most inputs:") + # node_inputs_sorted = sorted(node_inputs.items(), key=lambda x: len(x[1]), reverse=True) + # for node_name, input_list in node_inputs_sorted[:10]: + # ad_logger.debug(f" {node_name}: {len(input_list)}") + + # Print 10 nodes with most outputs + node_outputs_sorted = sorted(node_outputs.items(), key=lambda x: len(x[1]), reverse=True) + # ad_logger.debug("Nodes with most outputs:") + large_fanout_nodes: Dict[str, int] = {} + for node_name, output_list in node_outputs_sorted[:10]: + if len(output_list) > 10: + large_fanout_nodes[node_name] = 0 + # ad_logger.debug(f" {node_name}: {len(output_list)}") + + # Overwrite large fanout nodes style + for node_name in large_fanout_nodes: + # Add node to diagram + dot.node( + node_name, + label=node_name, + fillcolor="#ffdddd80", + color="#88888880", + shape="box", + style="filled,dashed,rounded", + tooltip=node_name, # Use node name as tooltip + ) + + node_inputs_sorted = sorted(node_inputs.items(), key=lambda x: len(x[1]), reverse=True) + large_fanin_nodes: Dict[str, int] = {} + for node_name, input_list in node_inputs_sorted[:10]: + if len(input_list) > 12: + large_fanin_nodes[node_name] = 0 + + for node_name in large_fanin_nodes: + # Add node to diagram + dot.node( + node_name, + label=node_name, + fillcolor="#ddffdd80", + color="#88888880", + shape="box", + style="filled,dashed,rounded", + tooltip=node_name, # Use node name as tooltip + ) + + # Create constant nodes + for const_node_name, const_value, target_name, input_idx in constants_to_create: + const_label = _format_constant_label(const_value) + + # Add constant node to dot graph + const_attrs = { + "label": const_label, + "shape": "box", + "style": "rounded,filled", + "fillcolor": "#ffffcc", # 
Light yellow background + "color": "#cccccc", # Light gray border + "width": "0.2", + "fontsize": "10", + } + dot.node(const_node_name, **const_attrs) + + # Add edges with ports and colors + for source_name, target_name, val, source_output_index, target_input_index in edges: + edge_attrs = {} + + # Calculate ports (for graphical display positioning) + input_list = node_inputs[target_name] + + # Use actual output count, not usage count + source_output_count = _determine_node_output_count(source_name) + + input_port = _get_input_port(target_input_index, len(input_list)) + output_port = _get_output_port(source_output_index, source_output_count) + + # Build node names with ports + source_name_port = f"{source_name}:{output_port}" if output_port else source_name + target_name_port = f"{target_name}:{input_port}" if input_port else target_name + + # Set edge color (based on actual output_index) + edge_color = _get_edge_color(source_name, source_output_index) + edge_attrs["color"] = edge_color + + # Add tensor shape and width information + # val can be FakeTensor, Tensor, or other types with .shape attribute + if val is not None and include_shapes and hasattr(val, "shape"): + shape_str = str(tuple(val.shape)) + # Add dtype if available + dtype_str = "" + if hasattr(val, "dtype"): + dtype_str = str(val.dtype).replace("torch.", "") + edge_attrs["xlabel"] = f"{shape_str}\n{dtype_str}" if dtype_str else shape_str + edge_attrs["fontsize"] = "10" + edge_attrs["fontcolor"] = "blue" + + # Calculate edge width based on element count + width = _calculate_edge_width(val) + edge_attrs["penwidth"] = str(width) + + # For those large fanout nodes, large fantout nodes will stuck the graphviz layout algorithm. + # So we need to duplicate the node, so each edge has its own source node. + # Make the layout algorithm work. + # So is large fanin nodes. 
+ if source_name in large_fanout_nodes: + node_attrs = { + "fillcolor": "#ddffdd80", + "color": "#88888880", + "shape": "box", + "style": "filled,dashed,rounded", + } + large_fanout_nodes[source_name] += 1 + source_name_port = source_name + f"___{large_fanout_nodes[source_name]}" + dot.node( + name=source_name_port, + label=source_name, + **node_attrs, + ) + + if target_name in large_fanin_nodes: + node_attrs = { + "fillcolor": "#ddffdd80", + "color": "#88888880", + "shape": "box", + "style": "filled,dashed,rounded", + } + large_fanin_nodes[target_name] += 1 + target_name_port = target_name + f"___{large_fanin_nodes[target_name]}" + dot.node(name=target_name_port, label=target_name, **node_attrs) + + dot.edge(source_name_port, target_name_port, **edge_attrs) + + # Save diagram + try: + dot.render(save_path, cleanup=True) + ad_logger.info(f"Diagram saved: {save_path}.{format}") + with open(save_path + ".txt", "w") as f: + f.write(str(graph_module.graph)) + except Exception as e: + ad_logger.error(f"Failed to save diagram: {e}") + + return dot + + +def analyze_graph_structure(graph_module: GraphModule) -> Dict[str, Any]: + """ + Analyze structural statistics of GraphModule + + Args: + graph_module: GraphModule to analyze + + Returns: + Dictionary containing structural statistics + """ + graph = graph_module.graph + nodes = list(graph.nodes) + + # Count node types + op_counts = {} + for node in nodes: + op_type = node.op + op_counts[op_type] = op_counts.get(op_type, 0) + 1 + + # Analyze connections + total_connections = 0 + for node in nodes: + if node.args: + for arg in node.args: + if isinstance(arg, fx.Node): + total_connections += 1 + elif isinstance(arg, (list, tuple)): + for sub_arg in arg: + if isinstance(sub_arg, fx.Node): + total_connections += 1 + + # Calculate graph complexity + complexity_score = len(nodes) + total_connections + + return { + "total_nodes": len(nodes), + "node_types": op_counts, + "total_connections": total_connections, + "complexity_score": complexity_score, + "graph_depth": _calculate_graph_depth(nodes), + } + + +def _get_node_label(graph_module: GraphModule, node: fx.Node) -> str: + """Get node label""" + if node.op == "call_function": + func_name = _get_function_name(node.target) + tokens = func_name.split(".") + assert len(tokens) <= 2, f"Function name {func_name} has more than 2 tokens" + label = tokens[0] if tokens[0] != "to" else func_name + elif node.op == "call_method": + label = str(node.target) + elif node.op == "call_module": + label = _get_module_name(graph_module, node.target) + elif node.op == "get_attr": + attr_name = str(node.target).split(".")[-1] if "." in str(node.target) else str(node.target) + label = attr_name + elif node.op == "placeholder": + label = "ph: " + str(node.name) + elif node.op == "output": + label = "out: " + str(node.name) + else: + label = node.op + return label + + +def _get_function_name(func) -> str: + """Get simplified function name""" + if hasattr(func, "__name__"): + return func.__name__ + + func_str = str(func) + + # Handle torch functions + if "torch." in func_str: + match = re.search(r"torch\.(\w+)\.(\w+)", func_str) + if match: + return f"{match.group(1)}.{match.group(2)}" + + match = re.search(r"torch\.(\w+)", func_str) + if match: + return match.group(1) + + # Handle built-in functions + if "built-in" in func_str: + match = re.search(r"'(\w+)'", func_str) + if match: + return match.group(1) + + return str(func).split(".")[-1] if "." 
in str(func) else str(func) + + +def _get_module_name(graph_module: GraphModule, target) -> str: + """Extract module name, handle numeric indices in Sequential""" + try: + # Try to get actual module type name + actual_module = graph_module.get_submodule(str(target)) + module_type = actual_module.__class__.__name__ + + # Extract the last part of module name + module_name = str(target).split(".")[-1] if "." in str(target) else str(target) + + # If it's numeric index (like modules in Sequential), show type name + if module_name.isdigit(): + return module_type + else: + return module_name + except Exception: + # If unable to get module, fall back to original logic + module_name = str(target).split(".")[-1] if "." in str(target) else str(target) + return module_name + + +def _calculate_graph_depth(nodes) -> int: + """Calculate maximum depth of the graph""" + # Build dependency relationships + dependencies = {} + for node in nodes: + dependencies[node.name] = [] + if node.args: + for arg in node.args: + if isinstance(arg, fx.Node): + dependencies[node.name].append(arg.name) + elif isinstance(arg, (list, tuple)): + for sub_arg in arg: + if isinstance(sub_arg, fx.Node): + dependencies[node.name].append(sub_arg.name) + + # Calculate depth of each node + depths = {} + + def calculate_depth(node_name): + if node_name in depths: + return depths[node_name] + + if not dependencies[node_name]: + depths[node_name] = 0 + return 0 + + max_dep_depth = max(calculate_depth(dep) for dep in dependencies[node_name]) + depths[node_name] = max_dep_depth + 1 + return depths[node_name] + + for node in nodes: + calculate_depth(node.name) + + return max(depths.values()) if depths else 0 diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 571a632f7d..7bfc2f8023 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -3,14 +3,16 @@ This module defines the base classes and interfaces for all transforms. 
""" +import os import time from abc import ABC from contextlib import contextmanager, nullcontext from dataclasses import dataclass from enum import Enum from functools import total_ordering, wraps -from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, final +import torch import torch.nn as nn from pydantic import BaseModel, Field from torch.fx import GraphModule, Node @@ -27,6 +29,7 @@ from ..utils._graph import ( ) from ..utils.cuda_mem_tracker import get_mem_info from ..utils.logger import ad_logger +from .graph_module_visualizer import to_dot # ANSI color codes for log formatting (set to False to disable colors) # NOTE: colors disabled by default to make logging in CI/CD pipelines easier to read @@ -102,6 +105,7 @@ class Stages(Enum): POST_LOAD_FUSION = "post_load_fusion" # post-loading fusion and perf optimizations of the graph CACHE_INIT = "cache_init" # initialization of cached attention + (KV) cache initialization VISUALIZE = "visualize" # visualization of the graph + EXPORT_ONNX = "export_onnx" # export the graph to onnx COMPILE = "compile" # graph compilation stage using low-level compilers like torch.compile def __lt__(self, other): @@ -166,6 +170,11 @@ class TransformConfig(BaseModel): default=False, description="Whether this transform requires shape propagation before it is applied.", ) + debug_visualize_dir: Optional[str] = Field( + default=None, + description="Debug visualization directory. None to disable visualization, " + "or a path string to specify the output directory.", + ) expect_mem_change: bool = Field( default=False, @@ -257,6 +266,7 @@ def with_transform_logging(call_fn: Callable) -> Callable: cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, + idx: int, ) -> nn.Module: prefix = f"[stage={self.config.stage.value}, transform={self.get_transform_key()}]" original_log = ad_logger.log @@ -268,7 +278,7 @@ def with_transform_logging(call_fn: Callable) -> Callable: ad_logger.log = _patched_log # type: ignore[assignment] try: - return call_fn(self, gm, cm, factory, shared_config) + return call_fn(self, gm, cm, factory, shared_config, idx) finally: ad_logger.log = original_log # type: ignore[assignment] @@ -346,6 +356,7 @@ class BaseTransform(ABC): cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, + idx: int, ) -> nn.Module: """Apply the transform to the graph. @@ -354,6 +365,7 @@ class BaseTransform(ABC): cm: The cached sequence interface defining the sequence interface. factory: The model factory used to build the model. shared_config: Global info shared between multiple transforms. + idx: The index of the transform in the pipeline. Returns: nn.Module: The transformed model. @@ -473,10 +485,33 @@ class BaseTransform(ABC): autodeploy_meta[self._mem_history_key] = mem_history self._set_autodeploy_meta(mod, autodeploy_meta) + self._visualize_graph(mod, idx) # return the graph module return mod + @final + def _visualize_graph(self, mod: nn.Module, idx: int) -> None: + """Visualize the graph if debug visualization is enabled. + Args: + mod: The graph module to visualize. + idx: The index of the transform in the pipeline. + Note: + we may want to consider doing this for each subgraph. 
+ See https://github.com/NVIDIA/TensorRT-LLM/issues/10203 + """ + if not isinstance(mod, torch.fx.GraphModule): + return + visualize_dir = self.config.debug_visualize_dir + if not visualize_dir: + return + if not os.path.exists(visualize_dir): + os.makedirs(visualize_dir) + name_stem = f"gm_{idx + 1:02d}_{self.get_transform_key()}" + visualize_path = os.path.join(visualize_dir, f"{name_stem}") + to_dot(mod, name=name_stem, save_path=visualize_path, format="svg") + ad_logger.debug(f"[{idx + 1:02d}] Visualized {name_stem} to {visualize_path}") + @final def _apply_per_gm_or_whole_model( self, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/_chat_template.py b/tensorrt_llm/_torch/auto_deploy/transform/library/_chat_template.py new file mode 100644 index 0000000000..cf3d2e0613 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/_chat_template.py @@ -0,0 +1,382 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +For EdgeLLM API. Processes the chat template to create a JSON file with +chat template data for the following: + +Roles: +- System +- User +- Assistant + +Messages: +- Role +- Content + - Type + - text + - image + - video + +The JSON file is saved to the exported ONNX model directory. + +This implementation uses the HF tokenizer's apply_chat_template method with test cases +to extract the actual prefix/suffix patterns used by the model, rather than trying +to parse the Jinja template directly. 
+""" + +import json +import os +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from transformers import AutoConfig, AutoProcessor, AutoTokenizer + +from ...utils.logger import ad_logger + + +def is_vlm(model_dir: str) -> bool: + """Check if the model is a VLM.""" + cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + cfg_dict = cfg.to_dict() + has_vision = "vision_config" in cfg_dict + has_phi4_vision = "image_embd_layer" in cfg_dict.get("embd_layer", {}) + if has_vision or has_phi4_vision: + ad_logger.debug("Set use_prompt_tuning to True") + return True + else: + ad_logger.debug("Set use_prompt_tuning to False") + return False + + +@dataclass +class Message: + role: str + content: Union[str, List[Dict[str, str]]] = field(default_factory=list) + + +@dataclass +class SystemMessage(Message): + role: str = "system" + content: str = "" + + +@dataclass +class UserMessage(Message): + role: str = "user" + content: str = "" + + +@dataclass +class MultimodalUserMessage(Message): + role: str = "user" + content: List[Dict[str, str]] = field( + default_factory=lambda: [{"type": "text", "text": ""}] + ) + + def add_text_content(self, text: str): + self.content.append({"type": "text", "text": text}) + + def add_image_content(self, image: str): + self.content.append({"type": "image", "image": image}) + + def add_video_content(self, video: str): + self.content.append({"type": "video", "video": video}) + + +@dataclass +class AssistantMessage(Message): + role: str = "assistant" + content: str = "" + # TODO: Add tool calling + + +# TODO: Add ToolMessage + + +def _format_messages( + tokenizer: Any, messages: List[Message], add_generation_prompt: bool = False +) -> str: + """ + Format the messages using the tokenizer's chat template. + + Args: + tokenizer: HuggingFace loaded tokenizer + messages: List of messages + add_generation_prompt: Whether to add generation prompt + + Returns: + Formatted text + + Raises: + ValueError: If unable to format messages + """ + try: + # Convert dataclass messages to dictionaries using asdict + message_dicts = [asdict(msg) for msg in messages] + + return tokenizer.apply_chat_template( + message_dicts, tokenize=False, add_generation_prompt=add_generation_prompt + ) + except Exception: + # Try fallback: convert list content to string for tokenizers that don't support multimodal + try: + message_dicts = [] + for msg in messages: + content = msg.content + # If content is a list, extract the first text element + if isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + content = item.get("text", "") + break + message_dicts.append({"role": msg.role, "content": content}) + + return tokenizer.apply_chat_template( + message_dicts, tokenize=False, add_generation_prompt=add_generation_prompt + ) + except Exception as e2: + raise ValueError( + f"Unable to format messages using HuggingFace tokenizer's apply_chat_template method." + f"Messages need to be in the format: role: , content: . " + f"Check INPUT_FORMAT.md for more details." + f"Error: {e2}" + ) from e2 + + +def _extract_prefix_suffix(text: str, placeholder: str) -> Tuple[str, str]: + """ + Extract prefix and suffix from the differential text by finding the placeholder content. 
+ + Args: + text: The text to extract the prefix and suffix from + placeholder : The placeholder content to search for in the formatted output + + Returns: + Tuple of (prefix, suffix) strings + """ + content_start = text.find(placeholder) + + if content_start == -1: + return "", "" + + prefix = text[:content_start] + suffix = text[content_start + len(placeholder) :] + + return prefix, suffix + + +def _extract_content_pattern( + tokenizer: Any, + system_prompt: SystemMessage, + content_type: str, + placeholder: str, + text_only_formatted: str, + placeholder_text: str, +) -> Optional[str]: + """ + Extract the pattern for a specific content type (image/video) by comparing + with text-only message. + + Args: + tokenizer: The loaded tokenizer + system_prompt: System message to use + content_type: Type of content ('image' or 'video') + placeholder: Placeholder string for the content + text_only_formatted: Formatted text-only message + placeholder_text: The text placeholder used + + Returns: + Extracted pattern string or None if failed or tokenizer does not support multimodal content + """ + # Create user message with the content type + user_with_content = MultimodalUserMessage() + if content_type == "image": + user_with_content.add_image_content(placeholder) + elif content_type == "video": + user_with_content.add_video_content(placeholder) + else: + return None + + with_content_formatted = _format_messages(tokenizer, [system_prompt, user_with_content]) + + # Extract the differential - what was added for this content type + if placeholder_text in text_only_formatted and placeholder_text in with_content_formatted: + # Find position after the placeholder in both + text_pos = text_only_formatted.find(placeholder_text) + len(placeholder_text) + content_pos = with_content_formatted.find(placeholder_text) + len(placeholder_text) + + # Get what comes after the placeholder in both + text_only_suffix = text_only_formatted[text_pos:] + with_content_suffix = with_content_formatted[content_pos:] + + # The pattern is what was added (the difference) + if text_only_suffix and with_content_suffix.endswith(text_only_suffix): + pattern = with_content_suffix[: -len(text_only_suffix)] + else: + pattern = with_content_suffix + + # Strip dynamic prefixes like "Image 1:" or "Video 1:" + pattern = re.sub(rf"^{content_type.capitalize()} \d+:\s*", "", pattern) + + return pattern if pattern else None + + return None + + +def process_chat_template(model_dir: str, output_dir: str) -> None: + """ + Process the chat template from model's tokenizer and create a JSON file + with parsed template information. + + This function uses the tokenizer's apply_chat_template method with various + test cases to extract the actual prefix/suffix patterns. 
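+    The extracted data is written to processed_chat_template.json in output_dir;
+    see the module docstring for an illustrative layout of that file.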
+ + Args: + model_dir: Path to the model directory containing tokenizer files + output_dir: Path to save the chat_template.json file + + Returns: + None + """ + ad_logger.info(f"Processing chat template from {model_dir}") + + tokenizer = None + loaders = ( + [AutoProcessor, AutoTokenizer] if is_vlm(model_dir) else [AutoTokenizer, AutoProcessor] + ) + for ldr in loaders: + try: + tokenizer = ldr.from_pretrained(model_dir, trust_remote_code=True) + if getattr(tokenizer, "chat_template", None): + ad_logger.debug(f"Successfully loaded chat template from {ldr.__name__}") + break + else: + ad_logger.debug(f"{ldr.__name__} loaded but no chat template found") + tokenizer = None + except Exception as e: + ad_logger.error(f"Failed to load {ldr.__name__}: {e}") + tokenizer = None + + if tokenizer is None: + ad_logger.debug("Skipping chat template processing - no chat template available") + return + + ad_logger.debug("Extracting patterns from chat template...") + + # Extract system role patterns (base case) + system_prompt = SystemMessage() + system_formatted = _format_messages(tokenizer, [system_prompt]) + system_prefix, system_suffix = _extract_prefix_suffix(system_formatted, system_prompt.content) + + # Extract user role patterns (compare with system base) + user_prompt = UserMessage() + user_formatted = _format_messages(tokenizer, [system_prompt, user_prompt]) + user_prefix, user_suffix = _extract_prefix_suffix( + user_formatted[len(system_formatted) :], user_prompt.content + ) + + # Extract assistant role patterns (compare with user case) + assistant_prompt = AssistantMessage() + assistant_formatted = _format_messages( + tokenizer, [system_prompt, user_prompt, assistant_prompt] + ) + assistant_prefix, assistant_suffix = _extract_prefix_suffix( + assistant_formatted[len(user_formatted) :], assistant_prompt.content + ) + + # Extract generation prompt + generation_formatted = _format_messages( + tokenizer, [system_prompt, user_prompt], add_generation_prompt=True + ) + generation_prompt = generation_formatted[len(user_formatted) :] + + # Build content types + content_types = {} + + # Only extract multimodal patterns if this is a VLM model + if is_vlm(model_dir): + ad_logger.debug("Detected VLM model, extracting multimodal content patterns...") + # Get base text-only formatted message for comparison + user_text_only = MultimodalUserMessage() + text_only_formatted = _format_messages(tokenizer, [system_prompt, user_text_only]) + placeholder_text = user_text_only.content[0]["text"] + + # Extract image pattern + image_pattern = _extract_content_pattern( + tokenizer, + system_prompt, + "image", + "", + text_only_formatted, + placeholder_text, + ) + if image_pattern: + content_types["image"] = {"format": image_pattern} + + # Extract video pattern + video_pattern = _extract_content_pattern( + tokenizer, + system_prompt, + "video", + "", + text_only_formatted, + placeholder_text, + ) + if video_pattern: + content_types["video"] = {"format": video_pattern} + else: + ad_logger.debug("Text-only LLM detected, skipping multimodal content pattern extraction") + + # Extract default system prompt by testing without system message + user_only_prompt = UserMessage() + user_only_formatted = _format_messages(tokenizer, [user_only_prompt]) + + # Extract default system prompt + default_system_prompt = "" + # Check if a default system prompt was added + # The system message should appear in user_only_formatted if there's a default + system_start = user_only_formatted.find(system_prefix) + if system_start != -1: + # Extract 
the system content between prefix and suffix + content_start = system_start + len(system_prefix) + content_end = user_only_formatted.find(system_suffix, content_start) + if content_end != -1: + default_system_prompt = user_only_formatted[content_start:content_end] + # Remove the placeholder if it appears + if default_system_prompt == system_prompt.content: + default_system_prompt = "" + + # Build the final JSON structure + chat_template_data = { + "model_path": model_dir, + "roles": { + "system": {"prefix": system_prefix, "suffix": system_suffix}, + "user": {"prefix": user_prefix, "suffix": user_suffix}, + "assistant": {"prefix": assistant_prefix, "suffix": assistant_suffix}, + }, + "content_types": content_types, + "generation_prompt": generation_prompt, + "default_system_prompt": default_system_prompt, + } + + # Save to output directory + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "processed_chat_template.json") + + with open(output_path, "w") as f: + json.dump(chat_template_data, f, indent=2) + + ad_logger.info(f"Chat template saved to {output_path}") diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/_config_export.py b/tensorrt_llm/_torch/auto_deploy/transform/library/_config_export.py new file mode 100644 index 0000000000..103307dd33 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/_config_export.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
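+
+"""Flatten a HuggingFace model config into the config.json consumed by EdgeLLM.
+
+Illustrative output for a plain LLM (field values below are hypothetical; the
+required keys are the ones validated in _export_native_llm_config):
+
+    {
+      "model": "llama", "model_type": "llm", "edgellm_version": "0.5.0.0",
+      "vocab_size": 32000, "hidden_size": 4096, "intermediate_size": 11008,
+      "max_position_embeddings": 4096, "num_hidden_layers": 32,
+      "num_attention_heads": 32, "num_key_value_heads": 8, "head_dim": 128,
+      "rope_theta": 10000.0, "rope_scaling": null, "partial_rotary_factor": 1.0
+    }
+"""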
+ +from typing import Any, Dict + +from ...utils.logger import ad_logger + +EDGELLM_VERSION = "0.5.0.0" + + +def _export_native_llm_config(config_dict: Dict[str, Any]) -> Dict[str, Any]: + """Export LLM configuration with required fields.""" + required_fields = [ + "vocab_size", + "max_position_embeddings", + "hidden_size", + "intermediate_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "rope_theta", + "rope_scaling", + ] + + llm_config = {} + for field in required_fields: + if field not in config_dict: + raise KeyError(f"Required field '{field}' not found in config") + llm_config[field] = config_dict[field] + + # Handle LongRoPE (rope_scaling already validated in required_fields) + rope_scaling = config_dict["rope_scaling"] + if rope_scaling and rope_scaling.get("type", None) == "longrope": + if "original_max_position_embeddings" not in config_dict: + raise KeyError("Required field 'original_max_position_embeddings' not found in config") + llm_config["original_max_position_embeddings"] = config_dict[ + "original_max_position_embeddings" + ] + + # Handle head_dim + if "head_dim" in config_dict: + llm_config["head_dim"] = config_dict["head_dim"] + else: + ad_logger.warning( + "Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads" + ) + llm_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"] + + if "partial_rotary_factor" in config_dict: + llm_config["partial_rotary_factor"] = config_dict["partial_rotary_factor"] + else: + llm_config["partial_rotary_factor"] = 1.0 + + llm_config["model_type"] = "llm" + return llm_config + + +def _export_eagle_base_config(config_dict: Dict[str, Any]) -> Dict[str, Any]: + """Export EAGLE base configuration with required fields.""" + required_fields = [ + "vocab_size", + "max_position_embeddings", + "hidden_size", + "intermediate_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "rope_theta", + "rope_scaling", + ] + + eagle_config = {} + for field in required_fields: + if field not in config_dict: + raise KeyError(f"Required field '{field}' not found in config") + eagle_config[field] = config_dict[field] + + # Handle head_dim + if "head_dim" in config_dict: + eagle_config["head_dim"] = config_dict["head_dim"] + else: + ad_logger.warning( + "Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads" + ) + eagle_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"] + if "partial_rotary_factor" in config_dict: + eagle_config["partial_rotary_factor"] = config_dict["partial_rotary_factor"] + else: + eagle_config["partial_rotary_factor"] = 1.0 + + eagle_config["model_type"] = "eagle3_base" + return eagle_config + + +def _export_eagle_draft_config(config_dict: Dict[str, Any]) -> Dict[str, Any]: + """Export EAGLE draft configuration with required fields.""" + required_fields = [ + "hidden_size", + "max_position_embeddings", + "intermediate_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "rope_theta", + "rope_scaling", + ] + + draft_config = {} + for field in required_fields: + if field not in config_dict: + raise KeyError(f"Required field '{field}' not found in config") + draft_config[field] = config_dict[field] + + # Handle head_dim + if "head_dim" in config_dict: + draft_config["head_dim"] = config_dict["head_dim"] + else: + ad_logger.warning( + "Warning: head_dim not found in config, calculating as hidden_size // num_attention_heads" + ) 
+ draft_config["head_dim"] = config_dict["hidden_size"] // config_dict["num_attention_heads"] + + # Handle draft_vocab_size based on EAGLE version + if "draft_vocab_size" not in config_dict: + raise KeyError("Required field 'draft_vocab_size' not found in config") + draft_config["draft_vocab_size"] = config_dict["draft_vocab_size"] + + # Add base model configuration fields + # The target_hidden_size from the model config represents the base model's hidden dimension + if "target_hidden_size" in config_dict: + # Use target_hidden_size * 3 as the base model hidden dimension (as per llm_export.py logic) + draft_config["base_model_hidden_size"] = config_dict["target_hidden_size"] * 3 + else: + # Fallback: assume base model hidden size is 3x draft model (Eagle3 default) + draft_config["base_model_hidden_size"] = config_dict["hidden_size"] * 3 + ad_logger.warning( + f"Warning: target_hidden_size not found, using default 3x draft hidden size: " + f"{draft_config['base_model_hidden_size']}" + ) + + # Set model_type for draft + draft_config["model_type"] = "eagle3_draft" + + return draft_config + + +def export_vision_config(config: Any) -> Dict[str, Any]: + """Export vision configuration without modification.""" + config_dict = config.to_dict() + + has_vision = "vision_config" in config_dict + has_phi4_vision = "image_embd_layer" in config_dict.get("embd_layer", {}) + if not (has_vision or has_phi4_vision): + raise KeyError( + "Required field 'vision_config' or 'image_embd_layer' in 'embd_layer' not found in config" + ) + # Add EdgeLLM API version + config_dict["edgellm_version"] = EDGELLM_VERSION + + # Return the original config_dict as-is without any modification + # Since MRoPE needs LLM config, ViTRunner will use the LLM config. + return config_dict + + +def export_llm_config(config: Any, model_type: str) -> Dict[str, Any]: + """Export configuration based on model type and EAGLE version.""" + config_dict = config.to_dict() + + # Extract model name from config class + config_class_name = config.__class__.__name__ + model_name = config_class_name.lower().replace("config", "") + + # For other model types, use text_config if available + if "text_config" in config_dict: + ad_logger.info("Detected multimodal model, using text_config") + config_dict = config_dict["text_config"] + + if model_type == "llm": + output_config = _export_native_llm_config(config_dict) + elif model_type == "eagle3_base": + output_config = _export_eagle_base_config(config_dict) + elif model_type == "eagle_draft": + output_config = _export_eagle_draft_config(config_dict) + else: + raise ValueError(f"Unsupported model type: {model_type}") + + # Add model name to output + output_config["model"] = model_name + + # Add EdgeLLM API version + output_config["edgellm_version"] = EDGELLM_VERSION + + return output_config diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py b/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py new file mode 100644 index 0000000000..0a13f780f8 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from onnx import defs +from onnxscript.values import Opset + +_TRT_DOMAIN_NAME = "trt" +_AUTO_DEPLOY_DOMAIN_NAME = "auto_deploy" + +# public opset objects, used in ONNX translation functions +trt_opset = Opset(_TRT_DOMAIN_NAME, 1) +auto_deploy_opset = Opset(_AUTO_DEPLOY_DOMAIN_NAME, 1) + + +# ONNX Custom Op Registration for RoPE +_torch_rope_with_explicit_cos_sin_schema = defs.OpSchema( + name="rope_with_explicit_cos_sin", + domain=_AUTO_DEPLOY_DOMAIN_NAME, + since_version=1, + doc="Rope with explicit cos and sin caches.", + inputs=[ + defs.OpSchema.FormalParameter( + name="q", + description="Q tensor", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="k", + description="K tensor", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="cos", + description="Cos cache", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="sin", + description="Sin cache", + type_str="T", + ), + ], + outputs=[ + defs.OpSchema.FormalParameter( + name="output", + description="Output tensor", + type_str="T", + ) + ], + type_constraints=[ + ( + "T", + ["tensor(float)", "tensor(float16)", "tensor(bfloat16)"], + "Input and output data type.", + ), + ], + attributes=[ + defs.OpSchema.Attribute( + name="unsqueeze_dim", + type=defs.OpSchema.AttrType.INT, + description="Unsqueeze dimension. Must be 1 or 2.", + required=True, + ), + ], +) + + +# ONNX Custom Op Registration for AttentionPlugin +_attention_plugin_schema = defs.OpSchema( + name="AttentionPlugin", + domain=_TRT_DOMAIN_NAME, + since_version=1, + doc="Fused RoPE + Attention operation for efficient inference.", + inputs=[ + defs.OpSchema.FormalParameter( + name="qkv", + description="Concatenated Q, K, V tensors in shape [batch, seq_len, qkv_hidden_size]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="past_key_values", + description="Concatenated past K and V cache in shape [batch, 2, num_kv_heads, past_len, head_size]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="context_lengths", + description="Context lengths for each sequence in shape [batch]", + type_str="T1", + ), + defs.OpSchema.FormalParameter( + name="rope_rotary_cos_sin", + description="Concatenated cos and sin values for RoPE in shape [max_seq_len, head_dim]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="kvcache_start_index", + description="KV cache start index for each sequence in shape [batch]", + type_str="T1", + ), + ], + outputs=[ + defs.OpSchema.FormalParameter( + name="output", + description="Attention output in shape [batch, seq_len, hidden_size]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="present_key_values", + description="Updated K and V cache", + type_str="T", + ), + ], + type_constraints=[ + ( + "T", + ["tensor(float16)", "tensor(float)", "tensor(bfloat16)"], + "Input and output data type for floating point tensors.", + ), + ( + "T1", + ["tensor(int32)", "tensor(int64)"], + "Input data type for integer tensors.", + ), + ], + attributes=[ + defs.OpSchema.Attribute( + name="enable_tree_attention", + type=defs.OpSchema.AttrType.INT, + description="Whether to enable tree attention 
(0 or 1).", + required=True, + ), + defs.OpSchema.Attribute( + name="head_size", + type=defs.OpSchema.AttrType.INT, + description="Size of each attention head.", + required=True, + ), + defs.OpSchema.Attribute( + name="num_kv_heads", + type=defs.OpSchema.AttrType.INT, + description="Number of key-value heads.", + required=True, + ), + defs.OpSchema.Attribute( + name="num_q_heads", + type=defs.OpSchema.AttrType.INT, + description="Number of query heads.", + required=True, + ), + ], +) + + +# ONNX Custom Op Registration for torch_attention +_torch_attention_schema = defs.OpSchema( + name="torch_attention", + domain=_AUTO_DEPLOY_DOMAIN_NAME, + since_version=1, + doc="SDPA attention (with optional GQA) that supports bnsd and bsnd memory layouts.", + inputs=[ + defs.OpSchema.FormalParameter( + name="query", + description="Query tensor [batch, seq_len_q/num_heads, num_heads/seq_len_q, head_dim]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="key", + description="Key tensor [batch, seq_len_k/num_kv_heads, num_kv_heads/seq_len_k, head_dim]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="value", + description="Value tensor [batch, seq_len_k/num_kv_heads, num_kv_heads/seq_len_k, head_dim]", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="attn_mask", + description="Optional attention mask in [batch, num_heads, seq_len_q, seq_len_k] layout", + type_str="T", + ), + defs.OpSchema.FormalParameter( + name="sinks", + description="Optional sinks tensor", + type_str="T", + ), + ], + outputs=[ + defs.OpSchema.FormalParameter( + name="output", + description="Attention output in the same layout as inputs", + type_str="T", + ) + ], + type_constraints=[ + ( + "T", + ["tensor(float16)", "tensor(float)", "tensor(bfloat16)"], + "Input and output data type for floating point tensors.", + ), + ], + attributes=[ + defs.OpSchema.Attribute( + name="dropout_p", + type=defs.OpSchema.AttrType.FLOAT, + description="Dropout probability.", + required=False, + ), + defs.OpSchema.Attribute( + name="is_causal", + type=defs.OpSchema.AttrType.INT, + description="Whether to apply causal masking (0 or 1).", + required=False, + ), + defs.OpSchema.Attribute( + name="scale", + type=defs.OpSchema.AttrType.FLOAT, + description="Attention scale factor.", + required=False, + ), + defs.OpSchema.Attribute( + name="sliding_window", + type=defs.OpSchema.AttrType.INT, + description="Sliding window size for attention.", + required=False, + ), + defs.OpSchema.Attribute( + name="logit_cap", + type=defs.OpSchema.AttrType.FLOAT, + description="Logit capping value.", + required=False, + ), + defs.OpSchema.Attribute( + name="layout", + type=defs.OpSchema.AttrType.STRING, + description="Memory layout: 'bnsd' or 'bsnd'.", + required=False, + ), + ], +) + + +def register_onnx_schemas(): + """Register ONNX custom ops.""" + defs.register_schema(_torch_rope_with_explicit_cos_sin_schema) + defs.register_schema(_torch_attention_schema) + defs.register_schema(_attention_plugin_schema) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py new file mode 100644 index 0000000000..90ef10bad1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +from typing import Tuple + +import torch +from torch.fx import GraphModule + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.logger import ad_logger +from ...utils.node_utils import is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +@TransformRegistry.register("adapt_to_edgellm") +class AdaptToEdgeLLM(BaseTransform): + """Transform that adapts the model graph for EdgeLLM deployment. + + This transform performs several modifications to make the model compatible + with EdgeLLM runtime requirements: + + 1. Converts all model weights to float16 precision + 2. Adds float32 cast after the final linear layer (for logits output) + 3. Inserts float16 casts after attention output reshapes + 4. Changes any bfloat16 casts to float16 (EdgeLLM may not support bfloat16) + + These modifications ensure proper data type handling throughout the model + while maintaining numerical precision where needed (e.g., logits output). + """ + + def _add_cast_after_last_linear(self, gm: GraphModule) -> torch.fx.Node: + """Add a float32 cast operation after the final linear layer. + + The final linear layer produces logits which are typically kept in + float32 for numerical stability during softmax and sampling operations. + + Args: + gm: The GraphModule to modify. + + Returns: + The newly created cast node. + """ + graph = gm.graph + linear_nodes = graph.find_nodes( + op="call_function", target=torch.ops.auto_deploy.torch_linear_simple.default, sort=True + ) + assert len(linear_nodes) > 0, "No linear nodes found" + last_linear_node = linear_nodes[-1] + with graph.inserting_after(last_linear_node): + cast_node = graph.call_function( + torch.ops.aten.to.dtype, args=(last_linear_node, torch.float32) + ) + last_linear_node.replace_all_uses_with(cast_node) + # Restore cast_node's input to last_linear_node + cast_node.update_arg(0, last_linear_node) + return cast_node + + def _insert_cast_after_attn_reshape(self, gm: GraphModule) -> int: + """Insert float16 cast after reshape nodes following AttentionPlugin output. + + The AttentionPlugin may output tensors that need explicit casting to float16 + before being consumed by subsequent operations (e.g., linear projections). + This ensures consistent data types throughout the attention block. + + Graph transformation: + Before: AttentionPlugin[0] -> Reshape -> MatMul + After: AttentionPlugin[0] -> Reshape -> Cast(to float16) -> MatMul + + Args: + gm: The GraphModule to modify. + + Returns: + Number of cast nodes inserted. 
+ """ + graph = gm.graph + + # Find all AttentionPlugin nodes + attention_plugin_nodes = graph.find_nodes( + op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default + ) + + num_inserted = 0 + for attn_node in attention_plugin_nodes: + # Find getitem[0] for this AttentionPlugin (first output) + for user in attn_node.users: + if is_op(user, operator.getitem) and user.args[1] == 0: + getitem_0_node = user + # Find reshape nodes that use this getitem[0] + for reshape_user in list(getitem_0_node.users): + if is_op(reshape_user, torch.ops.aten.reshape.default): + reshape_node = reshape_user + # Insert cast (to float16) after reshape + with graph.inserting_after(reshape_node): + cast_node = graph.call_function( + torch.ops.aten.to.dtype, + args=(reshape_node, torch.float16), + ) + reshape_node.replace_all_uses_with(cast_node) + # Fix: restore cast_node's input to reshape_node + # (replace_all_uses_with also replaced it) + cast_node.update_arg(0, reshape_node) + num_inserted += 1 + ad_logger.debug(f"Inserted cast (to float16) after {reshape_node.name}") + + return num_inserted + + def _change_cast_bfloat16_to_float16(self, gm: GraphModule) -> int: + """Replace all bfloat16 cast operations with float16 casts. + + EdgeLLM or certain hardware backends may not support bfloat16 natively. + This method converts all bfloat16 casts to float16 for compatibility. + + Args: + gm: The GraphModule to modify. + + Returns: + Number of cast operations changed. + """ + graph = gm.graph + cast_nodes = graph.find_nodes(op="call_function", target=torch.ops.aten.to.dtype) + num_changed = 0 + for cast_node in cast_nodes: + if cast_node.args[1] == torch.bfloat16: + cast_node.update_arg(1, torch.float16) + num_changed += 1 + return num_changed + + def _to_float16(self, gm: GraphModule) -> None: + """Convert all model parameters and buffers to float16 precision. + + Args: + gm: The GraphModule to convert. + """ + gm.half() + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + self._to_float16(gm) + logits_cast = self._add_cast_after_last_linear(gm) + assert logits_cast is not None, "Failed to add cast after last linear" + num_attn_casts = self._insert_cast_after_attn_reshape(gm) + num_bfloat16_casts = self._change_cast_bfloat16_to_float16(gm) + ad_logger.info(f"Changed {num_bfloat16_casts} bfloat16 casts to float16") + ad_logger.info(f"Adapted EdgeLLM model (inserted {num_attn_casts} attention casts)") + + return gm, TransformInfo( + skipped=False, num_matches=num_attn_casts, is_clean=False, has_valid_shapes=True + ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_onnx.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_onnx.py new file mode 100644 index 0000000000..daa7f155f1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_onnx.py @@ -0,0 +1,418 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Optional, Tuple, Type + +import torch +from onnxscript import ir, opset20 +from pydantic import Field +from torch.export import Dim +from torch.fx import GraphModule + +from ...models import hf +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.logger import ad_logger +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) +from . import _onnx_schemas +from ._chat_template import process_chat_template +from ._config_export import export_llm_config + + +class ExportToONNXConfig(TransformConfig): + """Configuration for the export to ONNX transform.""" + + output_dir: Path = Field( + description="The directory to save the exported ONNX model.", + ) + is_eagle_base: bool = Field( + description="Whether the model is an Eagle base model.", + default=False, + ) + + +# ============================================================================ +# Custom translation functions for ONNX export +# ============================================================================ + + +def _translate_rope_op( + q: ir.Tensor, k: ir.Tensor, cos: ir.Tensor, sin: ir.Tensor, unsqueeze_dim: int +): + return _onnx_schemas.auto_deploy_opset.rope_with_explicit_cos_sin( + q, k, cos, sin, unsqueeze_dim=unsqueeze_dim + ) + + +def _translate_simple_linear_op(input: ir.Tensor, weight: ir.Tensor, bias: Optional[ir.Tensor]): + weight = opset20.Transpose(weight, perm=[1, 0]) + if bias is None: + return opset20.MatMul(input, weight) + return opset20.Add(opset20.MatMul(input, weight), bias) + + +def _translate_gather_nd_op(data: ir.Tensor, indices: ir.Tensor, batch_dims: int): + return opset20.GatherND(data, indices, batch_dims=batch_dims) + + +def _translate_rope_attention_op( + qkv: ir.Tensor, + past_key_values: ir.Tensor, + context_lengths: ir.Tensor, + rope_rotary_cos_sin: ir.Tensor, + kvcache_start_index: ir.Tensor, + enable_tree_attention: int, + head_size: int, + num_kv_heads: int, + num_q_heads: int, +): + """ + ONNX custom op translation function for AttentionPlugin. + + This function creates a custom ONNX op node in the trt domain. + The actual implementation will need to be provided in the inference engine + (e.g., ONNX Runtime or TensorRT) that loads this ONNX model. + + Note: This is a translation function for torch.onnx.export's custom_translation_table, + so it should NOT have @script() decorator. + """ + # Call the custom op from the trt domain + return _onnx_schemas.trt_opset.AttentionPlugin( + qkv, + past_key_values, + context_lengths, + rope_rotary_cos_sin, + kvcache_start_index, + enable_tree_attention=enable_tree_attention, + head_size=head_size, + num_kv_heads=num_kv_heads, + num_q_heads=num_q_heads, + ) + + +def _translate_torch_attention_op( + query: ir.Tensor, + key: ir.Tensor, + value: ir.Tensor, + attn_mask: Optional[ir.Tensor] = None, + sinks: Optional[ir.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + sliding_window: Optional[int] = None, + logit_cap: Optional[float] = None, + layout: str = "bnsd", +): + """ + ONNX custom op translation function for torch_attention. + + This function creates a custom ONNX op node in the auto_deploy domain. 
+ The actual implementation will need to be provided in the inference engine + (e.g., ONNX Runtime or TensorRT) that loads this ONNX model. + + Note: This is a translation function for torch.onnx.export's custom_translation_table, + so it should NOT have @script() decorator. + """ + # Call the custom op from the auto_deploy domain + return _onnx_schemas.auto_deploy_opset.torch_attention( + query, + key, + value, + attn_mask, + sinks, + dropout_p=dropout_p, + is_causal=1 if is_causal else 0, + scale=scale, + sliding_window=sliding_window, + logit_cap=logit_cap, + layout=layout, + ) + + +@TransformRegistry.register("export_to_onnx") +class ExportToONNX(BaseTransform): + """Transform that exports a PyTorch GraphModule to ONNX format for deployment. + + This transform is responsible for: + 1. Exporting the model graph to ONNX format with dynamic shapes support + 2. Generating configuration files (config.json) for the exported model + 3. Saving tokenizer files (tokenizer.json, vocab.json, etc.) + 4. Processing and exporting chat templates + + The exported ONNX model includes custom ops from the auto_deploy. These custom ops include: + - torch_onnx_attention_plugin: Fused RoPE + Attention for efficient inference(exported as EdgeLLM's custom op) + - torch_onnx_gather_nd: N-dimensional gather operation (exported as onnxscript.opset20.GatherND) + + Note: + This transform does NOT modify the input graph. It only exports the graph + to external files and returns the original graph unchanged. + + Attributes: + config: ExportToONNXConfig containing output directory, and is_eagle_base flag. + """ + + config: ExportToONNXConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + """Return the configuration class for this transform.""" + return ExportToONNXConfig + + def _export_onnx_model( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + _factory: ModelFactory, + _shared_config: SharedConfig, + ) -> bool: + """Export the GraphModule to ONNX format using torch.onnx.export with dynamo. + + This method handles the core ONNX export logic: + 1. Extracts input placeholders and their metadata from the graph + 2. Configures dynamic shapes for batch size, sequence length, and KV cache + 3. Sets up custom translation table for auto_deploy custom ops + 4. Exports the model using torch.onnx.export with dynamo=True + + Args: + gm: The PyTorch FX GraphModule to export. + cm: CachedSequenceInterface containing max batch size and sequence length info. + _factory: ModelFactory instance (unused in this method). + _shared_config: SharedConfig instance (unused in this method). + + Returns: + bool: True if export was successful. 
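+
+        Note:
+            Dynamic axes registered below (sketch; max values come from
+            ``cm.info`` or fixed caps in the code):
+                input_ids:           [batch_size, seq_len]
+                context_lengths:     [batch_size]
+                rope_rotary_cos_sin: [rope_batch_size, max_position_embeddings, head_dim]
+                kvcache_start_index: [batch_size]
+                past_key_values_i:   [batch_size, 2, num_kv_heads, past_len, head_dim]
+                last_token_ids:      [batch_size, num_selected_tokens]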
+ """ + # Extract input placeholders from graph to build kwargs for export + args = [] + kwargs = {} + placeholders = gm.graph.find_nodes(op="placeholder") + for ph in placeholders: + kwargs[ph.name] = ph.meta["val"] + args = tuple(args) + + ad_logger.info("Placeholders args:") + for i, e in enumerate(args): + ad_logger.info(f" {i}: {placeholders[i].name:20} {e}") + + ad_logger.info("Placeholders kwargs:") + for k, v in kwargs.items(): + ad_logger.info(f" {k}: {v}") + + # Build output path + output_path = self.config.output_dir / "model.onnx" + + # Build dynamic_shapes for dynamo export + # For dynamo, we need to specify dynamic dimensions for each input tensor + dynamic_shapes = {} + dynamic_shapes["input_ids"] = { + 0: Dim("batch_size", min=0, max=cm.info.max_batch_size), + 1: Dim("seq_len", min=0, max=cm.info.max_seq_len), + } + # Add dynamic shapes for context_lengths and rope_rotary_cos_sin + dynamic_shapes["context_lengths"] = { + 0: Dim("batch_size", min=0, max=cm.info.max_batch_size) + } + dynamic_shapes["rope_rotary_cos_sin"] = { + 0: Dim("rope_batch_size", min=1, max=16), + 1: Dim("max_position_embeddings", min=1, max=4096), + } + dynamic_shapes["kvcache_start_index"] = { + 0: Dim("kv_cache_start_batch_size", min=0, max=cm.info.max_batch_size) + } + # Add dynamic shapes for past_key_values + num_layers = len( + gm.graph.find_nodes( + op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default + ) + ) + for i in range(num_layers): + dynamic_shapes[f"past_key_values_{i}"] = { + 0: Dim("batch_size", min=0, max=cm.info.max_batch_size), + 3: Dim("past_len", min=1, max=4096), + } + dynamic_shapes["last_token_ids"] = { + 0: Dim("batch_size", min=0, max=cm.info.max_batch_size), + 1: Dim("num_selected_tokens", min=1, max=cm.info.max_seq_len), + } + + # Create custom translation table for ONNX export + # Map torch custom ops to their corresponding onnxscript translation functions + custom_translation_table = { + # Before fuse rope attention + # NOTE (yoco): This 2 ops will be fused into the AttentionPlugin operation + # in the fuse_rope_attention transform. + # However, when TensorRT-LLM changed, the fusion might not be applied. + # And for debug purpose we might want to export the .onnx to check the graph. + # So let's just keep them here. 
+ torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin.default: _translate_rope_op, + torch.ops.auto_deploy.torch_attention.default: _translate_torch_attention_op, + # Before and after fuse rope attention + torch.ops.auto_deploy.torch_linear_simple.default: _translate_simple_linear_op, + # After fuse rope attention + torch.ops.auto_deploy.torch_onnx_attention_plugin.default: _translate_rope_attention_op, + torch.ops.auto_deploy.torch_onnx_gather_nd.default: _translate_gather_nd_op, + } + + # Prepare output names + output_names = [] + output_node = gm.graph.find_nodes(op="output")[0] + outputs = output_node.args[0] + for output in outputs: + output_names.append(output.name) + output_names[0] = "logits" + + # Register ONNX custom ops + _onnx_schemas.register_onnx_schemas() + + # Export the graph module to ONNX using dynamo (more advanced tracer) + ad_logger.info(f"Exporting GraphModule to ONNX with dynamo: {output_path}") + torch.onnx.export( + gm, + tuple(args), + output_path, + opset_version=20, + kwargs=kwargs, + dynamo=True, + dynamic_shapes=dynamic_shapes, + report=False, + output_names=output_names, + custom_translation_table=custom_translation_table, + ) + + ad_logger.info(f"Successfully exported ONNX model to {output_path}") + return True + + def _export_config_json( + self, factory: ModelFactory, output_dir: str, is_eagle_base: bool + ) -> None: + """Export model configuration to config.json. + + Generates a configuration file containing model architecture parameters + such as hidden size, number of layers, attention heads, etc. + + Args: + factory: The ModelFactory containing model configuration. + output_dir: Directory path where config.json will be saved. + is_eagle_base: If True, exports as "eagle3_base" model type, + otherwise exports as "llm" model type. + """ + model_type = "eagle3_base" if is_eagle_base else "llm" + assert isinstance(factory, hf.AutoModelFactory) + model_config, _ = factory._get_model_config() + model_config = export_llm_config(model_config, model_type) + # Add reduced_vocab_size to config if vocabulary reduction is used + reduced_vocab_size = None # TODO: Implement this + if reduced_vocab_size is not None: + model_config["reduced_vocab_size"] = reduced_vocab_size + ad_logger.info(f"Added reduced_vocab_size={reduced_vocab_size} to config") + + config_path = os.path.join(output_dir, "config.json") + with open(config_path, "w") as f: + json.dump(model_config, f, indent=2) + ad_logger.info(f"Model configuration saved to {config_path}") + + def _export_json_files( + self, + gm: GraphModule, + _cm: CachedSequenceInterface, + factory: ModelFactory, + _shared_config: SharedConfig, + ) -> None: + """Export all JSON configuration and tokenizer files required for deployment. + + This method orchestrates the export of: + 1. config.json - Model architecture configuration + 2. Tokenizer files - added_tokens.json, special_tokens_map.json, + tokenizer.json, tokenizer_config.json, vocab.json + 3. processed_chat_template.json - Processed chat template for inference + + Args: + gm: The GraphModule containing model configuration. + _cm: CachedSequenceInterface (unused, kept for interface consistency). + factory: ModelFactory used to initialize tokenizer and get model directory. + _shared_config: SharedConfig (unused, kept for interface consistency). + """ + # Export model configuration (architecture params, layer counts, etc.) 
+ is_eagle_base = self.config.is_eagle_base + output_dir = self.config.output_dir + self._export_config_json(factory, output_dir, is_eagle_base) + + # Export tokenizer files for text processing during inference + # Includes: added_tokens.json, special_tokens_map.json, tokenizer.json, + # tokenizer_config.json, vocab.json + tokenizer = factory.init_tokenizer() + tokenizer.save_pretrained(output_dir) + + # Export processed chat template for conversational inference + model_dir = factory.model + process_chat_template(model_dir, output_dir) + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + """Apply the ONNX export transform to the graph module. + + This is the main entry point that orchestrates the export process: + 1. Creates the output directory if it doesn't exist + 2. Exports all JSON configuration files (config, tokenizer, chat template) + 3. Exports the ONNX model with dynamic shapes and custom ops + + Note: + Unlike other transforms, this does NOT modify the graph. + It only performs side effects (writing files to disk). + + Args: + gm: The PyTorch FX GraphModule to export. + cm: CachedSequenceInterface containing runtime constraints. + factory: ModelFactory for tokenizer initialization. + shared_config: SharedConfig for transform coordination. + + Returns: + Tuple containing: + - gm: The original GraphModule (unchanged) + - info: TransformInfo with export status (is_clean=True since graph is unmodified) + """ + # Ensure output directory exists before writing any files + if not self.config.output_dir.exists(): + self.config.output_dir.mkdir(parents=True, exist_ok=True) + + # Step 1: Export all auxiliary JSON files (config, tokenizer, chat template) + self._export_json_files(gm, cm, factory, shared_config) + + # Step 2: Export the ONNX model with dynamic shapes + success = self._export_onnx_model(gm, cm, factory, shared_config) + + # Return original graph unchanged with export status info + # This transform is "clean" because it doesn't modify the graph structure + info = TransformInfo( + skipped=not success, + num_matches=1 if success else 0, + is_clean=True, # Graph is not modified + has_valid_shapes=True, # Shape validity is preserved + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py new file mode 100644 index 0000000000..bfbacea513 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py @@ -0,0 +1,551 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import operator +from typing import List, Tuple + +import torch +import transformers +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils._graph import add_graph_input, add_graph_output, remove_graph_input +from ...utils.logger import ad_logger +from ...utils.node_utils import extract_op_args, is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +class MatchResult: + """Container for matched RoPE + attention pattern nodes. + + This class stores all the relevant nodes and metadata from a successfully + matched RoPE (Rotary Position Embedding) + attention pattern in the graph. + It is used to facilitate the pattern replacement during the fusion transform. + + Attributes: + q: The original query tensor node before view/reshape. + k: The original key tensor node before view/reshape. + v: The original value tensor node before view/reshape. + cos: The cosine embedding node for RoPE. + sin: The sine embedding node for RoPE. + attn_node: The attention operation node to be replaced. + rope_node: The RoPE operation node to be fused. + head_dim: Dimension of each attention head. + num_q_heads: Number of query attention heads. + num_kv_heads: Number of key-value attention heads (for GQA/MQA). + """ + + def __init__( + self, + q: Node, + k: Node, + v: Node, + cos: Node, + sin: Node, + attn_node: Node, + rope_node: Node, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + ): + """Initialize MatchResult with matched pattern nodes and metadata. + + Args: + q: Query tensor node. + k: Key tensor node. + v: Value tensor node. + cos: RoPE cosine node. + sin: RoPE sine node. + attn_node: Attention operation node. + rope_node: RoPE operation node. + head_dim: Attention head dimension. + num_q_heads: Number of query heads. + num_kv_heads: Number of key-value heads. + """ + self.q = q + self.k = k + self.v = v + self.cos = cos + self.sin = sin + self.attn_node = attn_node + self.rope_node = rope_node + self.head_dim = head_dim + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + + def __repr__(self): + """Return string representation of the match result.""" + return ( + f"MatchResult(q={self.q.name}, k={self.k.name}, v={self.v.name}, " + f"cos={self.cos.name}, sin={self.sin.name}, head_dim={self.head_dim}, " + f"num_q_heads={self.num_q_heads}, num_kv_heads={self.num_kv_heads})" + ) + + +@TransformRegistry.register("fuse_rope_attention") +class FuseRopeAttention(BaseTransform): + """Transform that fuses RoPE and attention operations into a single AttentionPlugin. + + This transform identifies patterns in the graph where RoPE (Rotary Position + Embedding) is applied to query and key tensors followed by scaled dot-product + attention, and replaces them with a fused AttentionPlugin operation. + + The fusion provides several benefits: + - Reduced memory bandwidth by eliminating intermediate tensors + - Enables KV-cache integration for efficient autoregressive generation + - Allows backend-specific optimizations (e.g., TensorRT, EdgeLLM) + + Pattern matched (backwards from attention): + 1. torch_attention(rope_q, rope_k, view_v, attn_mask, ...) + 2. rope_q = rope[1], rope_k = rope[0] + 3. rope = torch_rope_with_explicit_cos_sin(cont_q, cont_k, cos, sin, 2) + 4. cont_q = contiguous(view_q), cont_k = contiguous(view_k) + 5. view_q = view(q, ...), view_k = view(k, ...), view_v = view(v, ...) 
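+
+    Replacement (sketch; the argument order follows _perform_replacement below):
+        qkv = cat(view(q), view(k), view(v), dim=-1).to(float16)
+        out, present_kv = AttentionPlugin(qkv, past_key_values_i, context_lengths,
+                                          rope_rotary_cos_sin, kvcache_start_index,
+                                          enable_tree_attention, head_size,
+                                          num_kv_heads, num_q_heads)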
+ + The transform also: + - Adds new graph inputs: context_lengths, rope_rotary_cos_sin, kvcache_start_index + - Adds past_key_values inputs for each attention layer + - Adds present_key_values outputs for each attention layer + - Removes the position_ids placeholder (no longer needed after fusion) + """ + + def _get_batch_size_and_max_seq_len(self, gm: GraphModule) -> Tuple[int, int]: + """Get batch size and max sequence length from the graph. + + Args: + gm: The GraphModule to get batch size and max sequence length from. + + Returns: + Tuple of (batch_size_dim, max_seq_len_dim, batch_size_sym_node, max_seq_len_sym_node). + + Note: + For clarity, we return the symbolic nodes as well, which can be + used in operations like view or reshape. + """ + graph = gm.graph + input_ids_node = graph.find_nodes(op="placeholder", target="input_ids")[0] + position_ids_node = graph.find_nodes(op="placeholder", target="position_ids")[0] + input_ids_meta = input_ids_node.meta.get("val") + batch_size_dim = input_ids_meta.size(0) + max_seq_len_dim = input_ids_meta.size(1) + + # We scan all the sym_size.int nodes to find the symbolic nodes for + # batch size and max sequence length. The symbolic nodes has two sources. + # 1. The input_ids placeholder. + # 2. The position_ids placeholder. + # Both of their shapes are (batch_size, max_seq_len). + sym_ints = graph.find_nodes(op="call_function", target=torch.ops.aten.sym_size.int) + for sym_int in sym_ints: + if sym_int.args[0] != input_ids_node and sym_int.args[0] != position_ids_node: + continue + if sym_int.args[1] == 0: + batch_size_sym_node = sym_int + elif sym_int.args[1] == 1: + max_seq_len_sym_node = sym_int + assert batch_size_sym_node is not None and max_seq_len_sym_node is not None + + return batch_size_dim, max_seq_len_dim, batch_size_sym_node, max_seq_len_sym_node + + def _get_config_head_dim(self, model_config: transformers.PretrainedConfig) -> int: + """Get head dimension from model config.""" + if hasattr(model_config, "head_dim"): + return model_config.head_dim + else: + return model_config.hidden_size // model_config.num_attention_heads + + def _match_rope_attention_pattern( + self, gm: GraphModule, model_config: transformers.PretrainedConfig + ) -> List[MatchResult]: + """Match RoPE + attention patterns in the computation graph. + + Traverses the graph backwards from attention nodes to identify the + complete pattern of RoPE application followed by attention computation. + + Pattern structure (backwards from attention): + 1. torch_attention(rope_q, rope_k, bind_v, attn_mask, ...) + 2. rope_q, rope_k = rope[1], rope[0] + 3. rope = torch_rope_with_explicit_cos_sin(bind_q, bind_k, cos, sin, 2) + + Args: + gm: The GraphModule to search for patterns. + + Returns: + List of MatchResult objects, each containing the matched nodes + and extracted metadata (head_dim, num_q_heads, num_kv_heads). + """ + matches = [] + graph = gm.graph + head_dim = self._get_config_head_dim(model_config) + + # Iterate through all nodes to find attention ops + for attn_node in graph.nodes: + if not is_op(attn_node, torch.ops.auto_deploy.torch_attention): + continue + + if attn_node.args[10] != "bsnd": + ad_logger.error( + f" Skipping: attention layout is not bsnd: {attn_node.kwargs.get('layout', None)}" + ) + continue + + ad_logger.debug(f"Found attention node: {attn_node.name}") + + # Extract attention inputs: (rope_q, rope_k, v, attn_mask, ...) 
+ if len(attn_node.args) < 4: + ad_logger.error(f" Skipping: insufficient args ({len(attn_node.args)})") + continue + + rope_q_node, rope_k_node, bind_v = extract_op_args(attn_node, "query", "key", "value") + + # Step 1: Match rope_q and rope_k as getitem[1] and getitem[0] from rope output + if not (is_op(rope_q_node, operator.getitem) and is_op(rope_k_node, operator.getitem)): + ad_logger.error(" Skipping: rope_q or rope_k not getitem") + continue + + # Verify they come from the same rope node + assert rope_q_node.target == operator.getitem + assert rope_k_node.target == operator.getitem + rope_node_from_q, rope_q_idx = rope_q_node.args + rope_node_from_k, rope_k_idx = rope_k_node.args + + if rope_node_from_q != rope_node_from_k: + ad_logger.error(" Skipping: rope_q and rope_k come from different rope nodes") + continue + + rope_node = rope_node_from_q + + # Verify getitem indices: rope[0] = rope_k, rope[1] = rope_q + if rope_k_idx != 0 or rope_q_idx != 1: + ad_logger.error( + f" Skipping: incorrect getitem indices (k={rope_k_idx}, q={rope_q_idx})" + ) + continue + + # Step 2: Match the rope node + if not is_op(rope_node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin): + ad_logger.error(" Skipping: not a rope node") + continue + + # Extract rope inputs: (cont_q, cont_k, cos, sin, 2) + if len(rope_node.args) < 5: + ad_logger.error(f" Skipping: rope has insufficient args ({len(rope_node.args)})") + continue + + bind_k = rope_node.args[0] + bind_q = rope_node.args[1] + cos_node = rope_node.args[2] + sin_node = rope_node.args[3] + + num_q_heads = model_config.num_attention_heads + num_kv_heads = model_config.num_key_value_heads + + # Successfully matched the pattern! + match = MatchResult( + q=bind_q, + k=bind_k, + v=bind_v, + cos=cos_node, + sin=sin_node, + attn_node=attn_node, + rope_node=rope_node, + head_dim=head_dim, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + ) + matches.append(match) + ad_logger.debug(f" ✓ Matched pattern: {match}") + + return matches + + def _add_global_placeholders(self, gm: GraphModule, factory: ModelFactory) -> tuple: + """Add global input placeholders required by the fused AttentionPlugin. + + Creates three new graph inputs that are shared across all attention layers: + - context_lengths: Actual sequence length for each batch element + - rope_rotary_cos_sin: Precomputed RoPE embeddings + - kvcache_start_index: Starting position in KV-cache for each batch + + These inputs enable dynamic sequence handling and efficient KV-cache + management during autoregressive generation. + + Args: + gm: GraphModule to add placeholders to. + cm: CachedSequenceInterface for registering dynamic shapes. + + Returns: + Tuple of (context_lengths_node, rope_rotary_cos_sin_node, + kvcache_start_index_node). + """ + + graph = gm.graph + + # Find token_ids placeholder to get batch_size symbolic dimension + token_ids_node = None + for node in graph.nodes: + if node.op == "placeholder" and "token" in node.name.lower(): + token_ids_node = node + break + + if token_ids_node is None: + # Fallback: use first placeholder + token_ids_node = graph.find_nodes(op="placeholder", sort=True)[0] + + batch_size_dim, max_seq_len_dim, _, _ = self._get_batch_size_and_max_seq_len(gm) + ad_logger.debug(f"Extracted batch_size={batch_size_dim}, max_seq_len={max_seq_len_dim}") + + # 1. 
Add context_lengths placeholder: int32[batch_size] + context_lengths_example = torch.zeros(batch_size_dim, dtype=torch.int32, device="meta") + context_lengths_node = add_graph_input( + gm, name="context_lengths", val=context_lengths_example + ) + ad_logger.debug(f"Added context_lengths placeholder: {context_lengths_node.name}") + + # 2. Add rope_rotary_cos_sin placeholder: float32[rope_batch_size, rope_max_position_length, 64] + # Create with concrete example tensor + model_config, _ = factory._get_model_config() + head_dim = self._get_config_head_dim(model_config) + rope_example = torch.zeros( + batch_size_dim, max_seq_len_dim, head_dim, dtype=torch.float32, device="meta" + ) + rope_rotary_cos_sin_node = add_graph_input(gm, name="rope_rotary_cos_sin", val=rope_example) + ad_logger.debug(f"Added rope_rotary_cos_sin placeholder: {rope_rotary_cos_sin_node.name}") + + # 3. Add kvcache_start_index placeholder: int32[batch_size] + kvcache_start_index_example = torch.zeros(batch_size_dim, dtype=torch.int32, device="meta") + kvcache_start_index_node = add_graph_input( + gm, name="kvcache_start_index", val=kvcache_start_index_example + ) + ad_logger.debug(f"Added kvcache_start_index placeholder: {kvcache_start_index_node.name}") + + return context_lengths_node, rope_rotary_cos_sin_node, kvcache_start_index_node + + def _perform_replacement( + self, + gm: GraphModule, + cm: "CachedSequenceInterface", + matches: List[MatchResult], + context_lengths_node: Node, + rope_rotary_cos_sin_node: Node, + kvcache_start_index_node: Node, + ) -> int: + """Replace matched RoPE + attention patterns with fused AttentionPlugin. + + For each matched pattern, this method: + 1. Creates a past_key_values input placeholder for KV-cache + 2. Reshape Q, K, V to (batch_size, seq_len, -1) + 3. Concatenates reshaped Q, K, V tensors into a single QKV tensor + 4. Inserts the fused AttentionPlugin operation + 5. Creates getitem nodes to extract attention output and present KV-cache + 6. Replaces the original attention node with the fused output + 7. Adds present_key_values to graph outputs + + Args: + gm: The GraphModule being transformed. + cm: CachedSequenceInterface for shape management. + matches: List of matched patterns to replace. + context_lengths_node: Shared context lengths input node. + rope_rotary_cos_sin_node: Shared RoPE embeddings input node. + kvcache_start_index_node: Shared KV-cache index input node. + + Returns: + Number of patterns successfully replaced. + """ + graph = gm.graph + past_len = 4096 # Does not matter, this value will be replaced by symbolic dimension + + # Get batch size & max sequence length + batch_size_dim, _, batch_size_sym_node, max_seq_len_sym_node = ( + self._get_batch_size_and_max_seq_len(gm) + ) + + # Process each match + for match_id, match in enumerate(matches): + ad_logger.debug(f"Processing match {match_id}: {match}") + + # 1. Create past_key_values_ placeholder + # Shape: float16[batch_size, 2, num_kv_heads, past_len, head_dim] + past_key_values_example = torch.zeros( + batch_size_dim, + 2, + match.num_kv_heads, + past_len, + match.head_dim, + dtype=torch.float16, + device="meta", + ) + past_key_values_node = add_graph_input( + gm, name=f"past_key_values_{match_id}", val=past_key_values_example + ) + + ad_logger.debug(f"Added past_key_values_{match_id} placeholder") + + # 2. 
Reshape Q, K, V to (batch_size, seq_len, -1) + with graph.inserting_before(match.attn_node): + q_node = graph.call_function( + torch.ops.aten.view.default, + args=(match.q, (batch_size_sym_node, max_seq_len_sym_node, -1)), + ) + k_node = graph.call_function( + torch.ops.aten.view.default, + args=(match.k, (batch_size_sym_node, max_seq_len_sym_node, -1)), + ) + v_node = graph.call_function( + torch.ops.aten.view.default, + args=(match.v, (batch_size_sym_node, max_seq_len_sym_node, -1)), + ) + + ad_logger.debug( + f"Reshaped Q, K, V to (batch_size, seq_len, -1): {q_node.name}, {k_node.name}, {v_node.name}" + ) + + # 3. Concatenate reshaped Q, K, V to (batch_size, seq_len, -1) + with graph.inserting_before(match.attn_node): + qkv_node = graph.call_function( + torch.ops.aten.cat.default, args=([q_node, k_node, v_node], -1) + ) + + ad_logger.debug(f"Created qkv concat node: {qkv_node.name}") + + # 4. Create AttentionPlugin node + enable_tree_attention = 0 + head_dim = match.head_dim + num_kv_heads = match.num_kv_heads + num_q_heads = match.num_q_heads + + with graph.inserting_before(match.attn_node): + cast_node = graph.call_function( + torch.ops.aten.to.dtype, args=(qkv_node, torch.float16) + ) + rope_attn_node = graph.call_function( + torch.ops.auto_deploy.torch_onnx_attention_plugin.default, + args=( + cast_node, + past_key_values_node, + context_lengths_node, + rope_rotary_cos_sin_node, + kvcache_start_index_node, + enable_tree_attention, + head_dim, + num_kv_heads, + num_q_heads, + ), + ) + + ad_logger.debug(f"Created AttentionPlugin node: {rope_attn_node.name}") + + # 5. Create getitem nodes for rope_attn outputs + with graph.inserting_after(rope_attn_node): + rope_attn_output = graph.call_function(operator.getitem, args=(rope_attn_node, 0)) + rope_attn_present_kv = graph.call_function( + operator.getitem, + args=(rope_attn_node, 1), + name=f"present_key_values_{match_id}", + ) + + ad_logger.debug( + f"Created getitem nodes: {rope_attn_output.name}, {rope_attn_present_kv.name}" + ) + + # 6. Replace all uses of attn_node with rope_attn_output + match.attn_node.replace_all_uses_with(rope_attn_output) + ad_logger.debug(f"Replaced {match.attn_node.name} with {rope_attn_output.name}") + + # 7. Add present_key_values_ to graph outputs using add_graph_output + add_graph_output(gm, rope_attn_present_kv, f"present_key_values_{match_id}") + ad_logger.debug(f"Added present_key_values_{match_id} to graph outputs") + + # 8. Eliminate dead code (remove old pattern nodes) + graph.eliminate_dead_code() + ad_logger.debug("Eliminated dead code") + + return len(matches) + + def _remove_position_ids_placeholder(self, gm: GraphModule) -> str: + """Remove the position_ids placeholder from the graph. + + After fusing RoPE into AttentionPlugin, position_ids is no longer needed + as an input since positional information is now provided through + rope_rotary_cos_sin and kvcache_start_index. + + This method also updates any sym_size operations that were reading from + position_ids to use input_ids instead. + + Args: + gm: The GraphModule to modify. + + Returns: + Name of the removed position_ids node. 
+ """ + graph = gm.graph + position_ids_node = graph.find_nodes(op="placeholder", target="position_ids")[0] + input_ids_node = graph.find_nodes(op="placeholder", target="input_ids")[0] + + # Get size from token_ids placeholder instead of position_ids placeholder + sym_size_int_nodes = graph.find_nodes( + op="call_function", target=torch.ops.aten.sym_size.int + ) + for node in sym_size_int_nodes: + if node.args[0] == position_ids_node: + new_args = (input_ids_node, *node.args[1:]) + node.args = new_args + + return remove_graph_input(gm, position_ids_node) + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + ad_logger.debug("Fusing rope attention") + + # Perform pattern matching + model_config, _ = factory._get_model_config() + matches = self._match_rope_attention_pattern(gm, model_config) + cnt = len(matches) + + ad_logger.info(f"Matched {cnt} rope attention patterns") + + if cnt == 0: + ad_logger.info("No patterns matched, skipping replacement") + info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) + return gm, info + + # Add global placeholders + context_lengths_node, rope_rotary_cos_sin_node, kvcache_start_index_node = ( + self._add_global_placeholders(gm, factory) + ) + + # Perform replacement for all matches + num_replaced = self._perform_replacement( + gm, + cm, + matches, + context_lengths_node, + rope_rotary_cos_sin_node, + kvcache_start_index_node, + ) + + # Remove position_ids placeholder + removed_node_name = self._remove_position_ids_placeholder(gm) + ad_logger.debug(f"Removed position_ids placeholder: {removed_node_name}") + + info = TransformInfo( + skipped=False, num_matches=num_replaced, is_clean=True, has_valid_shapes=True + ) + ad_logger.info(f"Fused {num_replaced} rope attention patterns") + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_last_token_ids.py b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_last_token_ids.py new file mode 100644 index 0000000000..85654a8bf0 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_last_token_ids.py @@ -0,0 +1,162 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils._graph import add_graph_input +from ...utils.logger import ad_logger +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +@TransformRegistry.register("gather_last_token_ids") +class GatherLastTokenIds(BaseTransform): + def _find_last_linear_simple(self, gm: GraphModule) -> Node: + """Find the final torch_linear_simple node that produces logits. + This is the last matmul operation in the graph + that produces the vocabulary logits. 
+ """ + graph = gm.graph + + # Look for the output node and trace back to find MatMul + output_node = graph.find_nodes(op="output")[0] + + # Look for the last torch_linear_simple node + linear_simple_nodes = graph.find_nodes( + op="call_function", target=torch.ops.auto_deploy.torch_linear_simple.default, sort=True + ) + assert len(linear_simple_nodes) > 0, "No linear simple nodes found" + last_linear_simple_node = linear_simple_nodes[-1] + + # Verify that the last linear simple node is the one producing the logits + if last_linear_simple_node not in output_node.args[0]: + ad_logger.warning( + f"Last linear simple node {last_linear_simple_node.name} is not the one producing the logits" + ) + return None + + return last_linear_simple_node + + def _add_last_token_ids_input(self, gm: GraphModule) -> Node: + """Add last_token_ids as a graph input. + + Shape: int64[batch_size, num_selected_tokens] + """ + graph = gm.graph + + # Find token_ids or input_ids placeholder to get batch_size dimension + token_ids_node = None + for node in graph.nodes: + if node.op == "placeholder" and ( + "token" in node.name.lower() or "input_ids" in node.name + ): + token_ids_node = node + break + + if token_ids_node is None: + # Fallback: use first placeholder + token_ids_node = graph.find_nodes(op="placeholder", sort=True)[0] + + ad_logger.info(f"Using {token_ids_node.name} to extract batch_size dimension") + + # Get symbolic batch_size dimension + token_ids_meta = token_ids_node.meta.get("val") + assert token_ids_meta is not None, "token_ids_meta is None" + batch_size_dim = token_ids_meta.size(0) + ad_logger.info(f"Extracted batch_size={batch_size_dim}") + + # Add last_token_ids placeholder: int64[batch_size] + num_selected_tokens = 2 + input_shape = (batch_size_dim, num_selected_tokens) + last_token_ids_example = torch.zeros(input_shape, dtype=torch.int64, device="meta") + last_token_ids_node = add_graph_input(gm, name="last_token_ids", val=last_token_ids_example) + + ad_logger.info(f"Added last_token_ids placeholder: {last_token_ids_node.name}") + + return last_token_ids_node + + def _insert_gather_nd( + self, + gm: GraphModule, + linear_simple: Node, + last_token_ids_node: Node, + ) -> bool: + """Insert GatherND operation before MatMul. + + The pattern to create: + 1. Create GatherND with the linear input and last_token_ids + 2. Replace MatMul input with GatherND output + """ + graph = gm.graph + + # Get the input to linear_simple (should be from the previous layer) + linear_input = linear_simple.args[0] + + ad_logger.info( + f"Linear input: {linear_input.name if hasattr(linear_input, 'name') else linear_input}" + ) + + # 1. Insert GatherND operation before linear_simple + # GatherND takes the hidden states (linear_input) and gathers based on last_token_ids + with graph.inserting_before(linear_simple): + unsqueeze_node = graph.call_function( + torch.ops.aten.unsqueeze.default, args=(last_token_ids_node, -1) + ) + gather_nd_node = graph.call_function( + torch.ops.auto_deploy.torch_onnx_gather_nd.default, + args=(linear_input, unsqueeze_node, 1), # Use linear_input, not linear_simple! + ) + ad_logger.info(f"Created GatherND node: {gather_nd_node.name}") + + # 2. 
Replace linear_simple's first input with GatherND output + linear_simple.replace_input_with(linear_input, gather_nd_node) + ad_logger.info("Replaced Linear input with GatherND output") + + return True + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + ad_logger.info("Adding last_token_ids gather operation") + + # Step 1: Find last linear simple node + linear_simple_node = self._find_last_linear_simple(gm) + + if linear_simple_node is None: + ad_logger.info("Could not find last linear simple node, skipping") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + # Step 2: Add last_token_ids input + last_token_ids_node = self._add_last_token_ids_input(gm) + + # Step 3: Insert GatherND operation + success = self._insert_gather_nd(gm, linear_simple_node, last_token_ids_node) + + if not success: + ad_logger.info("Failed to insert GatherND operation") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py b/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py new file mode 100644 index 0000000000..3ecceffa3e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py @@ -0,0 +1,165 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import torch +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +@TransformRegistry.register("short_reshape_attention_output") +class ShortReshapeAttentionOutput(BaseTransform): + """Transform that simplifies reshape operations after attention outputs. + + This transform optimizes reshape nodes that follow AttentionPlugin outputs + by replacing fixed shape dimensions with dynamic symbolic dimensions derived + from the input tensor. This ensures the reshape operation correctly handles + dynamic batch sizes and sequence lengths. + + The transform identifies patterns where: + 1. A reshape node follows an AttentionPlugin output + 2. The reshape's input can be traced back to a linear projection + + It then replaces the reshape with a new one that uses symbolic dimensions + from the linear projection's input, resulting in the shape [batch, seq, -1]. + """ + + def _lookup_ascending_node(self, node: Node, target, max_depth: int = 3) -> Node: + """Recursively search for an ancestor node with a specific target operation. 
+ + Traverses the graph backwards through node arguments to find a node + that matches the specified target operation type. + + Args: + node: Starting node for the backward search. + target: Target operation to search for (e.g., torch.ops.aten.reshape.default). + max_depth: Maximum number of levels to traverse (default: 3). + + Returns: + The matching ancestor node if found, None otherwise. + """ + if max_depth == 0: + return None + if node.target == target: + return node + + # Helper function to check a single node + def check_node(n): + if isinstance(n, Node): + result = self._lookup_ascending_node(n, target, max_depth - 1) + if result is not None: + return result + return None + + # Check all arguments + for arg in node.args: + if isinstance(arg, (tuple, list)): + for item in arg: + result = check_node(item) + if result is not None: + return result + else: + result = check_node(arg) + if result is not None: + return result + return None + + def _find_reshape_attention_output(self, gm: GraphModule) -> List[Node]: + """Find reshape nodes that follow AttentionPlugin outputs. + + Searches for reshape operations that: + 1. Have an AttentionPlugin in their input chain + 2. Can be traced back to a torch_linear_simple operation + + The linear operation is needed to extract the original input shape + for constructing the new reshape with symbolic dimensions. + + Args: + gm: The GraphModule to search. + + Returns: + List of (reshape_node, linear_node) tuples representing the + matched patterns. + """ + reshape_linear_pairs = [] + reshape_nodes = gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.reshape.default + ) + + for node in reshape_nodes: + # Looking for AttentionPlugin from the input of the reshape node + attention_plugin_node = self._lookup_ascending_node( + node.args[0], torch.ops.auto_deploy.torch_onnx_attention_plugin.default + ) + if attention_plugin_node is None: + continue + + # Looking for torch_linear_simple from the input of the AttentionPlugin node + linear_simple_node = self._lookup_ascending_node( + attention_plugin_node.args[0], torch.ops.auto_deploy.torch_linear_simple.default + ) + if linear_simple_node is None: + continue + + reshape_linear_pairs.append((node, linear_simple_node)) + + return reshape_linear_pairs + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + num_matches = 0 + reshape_linear_pairs = self._find_reshape_attention_output(gm) + self._log_info(f"Found {len(reshape_linear_pairs)} reshape_linear_pairs") + + graph = gm.graph + for reshape_node, linear_simple_node in reshape_linear_pairs: + # Get the input tensor of reshape node (keep it unchanged) + shape_src = linear_simple_node.args[0] + data_input = reshape_node.args[0] + + # Extract first two dimensions from the input tensor + with graph.inserting_before(reshape_node): + dim0 = graph.call_function(torch.ops.aten.sym_size.int, args=(shape_src, 0)) + dim1 = graph.call_function(torch.ops.aten.sym_size.int, args=(shape_src, 1)) + + # Create new shape using input tensor's first two dimensions + new_shape = [dim0, dim1, -1] + + # Create a NEW reshape node instead of modifying the existing one + # This ensures the graph dependencies are correctly established + new_reshape_node = graph.call_function( + torch.ops.aten.reshape.default, args=(data_input, new_shape) + ) + + # Replace all uses of the old reshape node with the new one + 
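# (graph.erase_node below requires the old node to have no remaining users)
+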
reshape_node.replace_all_uses_with(new_reshape_node) + + # Remove the old reshape node from the graph + graph.erase_node(reshape_node) + + num_matches += 1 + + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py index ea510a854f..ce43730928 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py @@ -64,11 +64,11 @@ class InferenceOptimizer: mod = nn.Module() # iterate over all transforms sorted by stage in the config - for t_name, t_config in self.config.items(): + for idx, (t_name, t_config) in enumerate(self.config.items()): # instantiate transform transform = TransformRegistry.get(t_name)(t_config) # run transform - mod = transform(mod, cm, self.factory, self.shared_config) + mod = transform(mod, cm, self.factory, self.shared_config, idx) ############################################################################################ # RETURN OPTIMIZED MODEL diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index 270e1ea4fe..749b2cd130 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -15,12 +15,26 @@ from torch._subclasses import FakeTensor, FakeTensorMode from torch.fx import Graph, GraphModule, Node from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.shape_prop import _extract_tensor_metadata -from torch.utils._pytree import _LEAF_SPEC +from torch.utils._pytree import _LEAF_SPEC, TreeSpec from .logger import ad_logger from .node_utils import get_weight_tensor, is_op _NoValType = type("_NoValType", (), {}) + + +def _call_post_init_recursive(spec: TreeSpec) -> None: + """Recursively call __post_init__ on TreeSpec starting from leaves. + + This is needed to update internal cached values (like num_leaves) after + modifying children_specs. + """ + if hasattr(spec, "children_specs"): + for child_spec in spec.children_specs: + _call_post_init_recursive(child_spec) + spec.__post_init__() + + _NO_VAL = _NoValType() @@ -367,12 +381,7 @@ def add_graph_input( in_spec_for_args.context.append(name) # update pytree info recursively with __post_init__ starting at leaves - def call_post_init(spec): - for child_spec in spec.children_specs: - call_post_init(child_spec) - spec.__post_init__() - - call_post_init(in_spec) + _call_post_init_recursive(in_spec) # set fake tensor information if all required information is available fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) @@ -527,3 +536,194 @@ def del_attr_by_name(obj, name): for part in parts[:-1]: obj = getattr(obj, part) delattr(obj, parts[-1]) + + +def add_graph_output(gm: GraphModule, output_node: Node, name: str) -> None: + """Add a graph output to the given GraphModule. + + This function appends a new output to the graph's output node and updates + the pytree_info metadata accordingly. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to add the output to. + output_node (Node): The node to add as an output. + name (str): The name of the new output in the output dict. + + Raises: + RuntimeError: If the graph has no output node or multiple output nodes. + + Implementation Notes (Hacky Details): + 1. 
**Handling single-output graphs (out_spec == _LEAF_SPEC)**:
+           When the graph originally has a single output, its out_spec is a
+           _LEAF_SPEC (a leaf node in the pytree). To add a new output, we must
+           first convert it to a container spec (tuple with children_specs).
+
+        2. **Forcing output type to dict**:
+           The original out_spec.type is typically a model-specific output class
+           (e.g., CausalLMOutputWithPast from transformers). When we add new
+           outputs with custom names (like "present_key_values_xx"), the original
+           class constructor will fail because these names are not valid data
+           members of that class.
+
+           PyTorch uses the out_spec.type's constructor to wrap the output:
+               result_obj = result_type(**output)
+
+           To avoid constructor failures, we forcibly change out_spec.type to
+           dict using object.__setattr__ (since TreeSpec is a frozen dataclass).
+           This ensures the output is returned as a plain dictionary instead of
+           the original model-specific class.
+    """
+    # extract graph and output spec
+    graph: Graph = gm.graph
+
+    # find the output node
+    output_nodes = graph.find_nodes(op="output")
+    if len(output_nodes) != 1:
+        raise RuntimeError(f"Expected exactly one output node, found {len(output_nodes)}")
+
+    graph_output_node = output_nodes[0]
+
+    # get current outputs
+    current_outputs = graph_output_node.args[0]
+    if not isinstance(current_outputs, (tuple, list)):
+        # single output, convert to tuple
+        current_outputs = (current_outputs,)
+
+    # append new output
+    new_outputs = tuple(current_outputs) + (output_node,)
+    graph_output_node.args = (new_outputs,)
+
+    # update pytree info: append spec for the new output
+    # The out_spec structure mirrors the output structure.
+    # For a tuple of outputs, out_spec should be a TreeSpec with children_specs.
+    # graph._codegen.pytree_info is a NamedTuple, so we have to use _replace to replace the out_spec.
+    out_spec = graph._codegen.pytree_info.out_spec
+    assert out_spec is not None, "Graph must have an out_spec"
+
+    if out_spec == _LEAF_SPEC:
+        new_out_spec = TreeSpec(type=tuple, children_specs=[_LEAF_SPEC], context=["output"])
+        graph._codegen.pytree_info = graph._codegen.pytree_info._replace(out_spec=new_out_spec)
+        out_spec = graph._codegen.pytree_info.out_spec
+    if hasattr(out_spec, "children_specs"):
+        # out_spec is already a container (tuple/list), append to it
+        out_spec.children_specs.append(_LEAF_SPEC)
+        out_spec.context.append(name)
+    else:
+        # This shouldn't happen in normal cases, but handle it just in case.
+        # If out_spec is a single leaf, we need to convert it to a tuple spec.
+        raise NotImplementedError(
+            "Cannot add output to a graph with non-container out_spec. "
+            "This case is not currently supported."
+        )
+
+    # update pytree info recursively with __post_init__ starting at leaves
+    _call_post_init_recursive(out_spec)
+
+    # Change the type of the output spec to dict.
+    #
+    # NOTE(yocox): This will affect how torch interprets the output of the graph.
+    # The original type is some LLM output result class,
+    # for example, CausalLMOutputWithPast, depending on the model.
+    # To add new fields to the output, we need to change the type to dict.
+    # The type's constructor will be used to create an object containing
+    # the result, for example,
+    #
+    #     result_obj = result_type(**output)
+    #
+    # Because we added new outputs with names like "present_key_values_xx", which
+    # are not data members of CausalLMOutputWithPast, that constructor would fail.
+    # So we need to change the type to dict.
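+    # After this change, the wrapped graph result is a plain dict, e.g. roughly
+    # {"logits": ..., "present_key_values_0": ...}, instead of the model-specific class.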
+ # + # However the out_spec.type is frozen dataclass, + # so we need to use object.__setattr__ to change it. + object.__setattr__(out_spec, "type", dict) + + +def remove_graph_input(gm: GraphModule, input_node: Node) -> str: + """Remove a graph input from the given GraphModule. + + This is the inverse operation of add_graph_input(). It removes a placeholder node + and updates the pytree_info metadata accordingly. The function automatically detects + whether the node belongs to args or kwargs and updates the correct spec. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to remove the input from. + input_node (Node): The placeholder node to remove. + + Returns: + str: The name of the removed node. + + Raises: + ValueError: If the provided Node is not a placeholder or not in the graph. + RuntimeError: If the input node is still being used by other nodes. + """ + graph: Graph = gm.graph + + # validate that the node is a placeholder + if input_node.op != "placeholder": + raise ValueError( + f"Node '{input_node.name}' is not a placeholder node (op='{input_node.op}'). " + f"Only placeholder nodes can be removed as graph inputs." + ) + + # find all placeholder nodes and validate the node is in this graph + placeholder_nodes = graph.find_nodes(op="placeholder", sort=True) + if input_node not in placeholder_nodes: + raise ValueError( + f"Node '{input_node.name}' is not found in the graph's placeholder nodes. " + f"It may belong to a different graph or have already been removed." + ) + + # check that the node is not being used + if len(input_node.users) > 0: + user_names = [user.name for user in input_node.users.keys()] + raise RuntimeError( + f"Cannot remove input '{input_node.name}' " + f"because it is still being used by nodes: {user_names}. " + f"Please remove or replace all usages first." 
+ ) + + # get pytree info + in_spec = graph._codegen.pytree_info.in_spec + args_spec = in_spec.children_specs[0] # tuple for args + kwargs_spec = in_spec.children_specs[1] # dict for kwargs + orig_args = graph._codegen.pytree_info.orig_args + + # determine the global index of the node + global_idx = placeholder_nodes.index(input_node) + + # determine number of args (kwargs come after args in placeholder order) + num_args = len(args_spec.children_specs) + + # determine if this is an arg or kwarg, and compute relative index + if global_idx < num_args: + # it's an arg + relative_idx = global_idx + target_spec = args_spec + is_kwarg = False + else: + # it's a kwarg + relative_idx = global_idx - num_args + target_spec = kwargs_spec + is_kwarg = True + + # save the node name before removing + removed_node_name = input_node.name + + # remove the node from the graph + graph.erase_node(input_node) + + # update pytree info: remove the corresponding spec and arg name + target_spec.children_specs.pop(relative_idx) + if is_kwarg: + target_spec.context.pop(relative_idx) + orig_args.pop(global_idx) + + # update pytree info recursively with __post_init__ starting at leaves + _call_post_init_recursive(in_spec) + + return removed_node_name diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py b/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py index 3720959cd7..710667924f 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py @@ -56,7 +56,7 @@ class TorchAttentionReference: v_flat = v.reshape(1, batch_size * seq_len, -1) # Call torch backend via custom op registry - output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache( + output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( q_flat, k_flat, v_flat, @@ -97,7 +97,7 @@ class TorchAttentionReference: This function directly calls the torch backend implementation via custom op registry. """ - return torch.ops.auto_deploy.torch_cached_attention_with_cache( + return torch.ops.auto_deploy.torch_cached_attention_with_cache.default( q, k, v, @@ -148,7 +148,7 @@ class TorchAttentionReference: batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32) # Call torch backend via custom op registry - output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache( + output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache.default( q_flat, k_flat, v_flat, @@ -185,7 +185,7 @@ class TorchAttentionReference: This demonstrates how to use the torch backend with additional features. 
""" - return torch.ops.auto_deploy.torch_cached_attention_with_cache( + return torch.ops.auto_deploy.torch_cached_attention_with_cache.default( q, k, v, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_export_onnx.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_export_onnx.py new file mode 100644 index 0000000000..3de8527263 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_export_onnx.py @@ -0,0 +1,65 @@ +import os + +import onnx +import pytest +from _model_test_utils import get_small_model_config + +from tensorrt_llm._torch.auto_deploy.export import export_onnx +from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs + + +@pytest.mark.parametrize( + "model, output_dir, num_attn_ops", + [ + ( + "Qwen/Qwen2.5-3B-Instruct", + "/tmp/test_ad_export_onnx_qwen2.5-3b", + 36, + ), + ], +) +def test_ad_export_onnx(model: str, output_dir: str, num_attn_ops: int): + ad_config = LlmArgs( + model=get_small_model_config(model)["args"]["model"], + mode="export_edgellm_onnx", + max_batch_size=13, + max_seq_len=4, + ) + ad_config.transforms["export_to_onnx"]["output_dir"] = output_dir + export_onnx(ad_config) + + # check if the output directory exists + assert os.path.exists(output_dir) + + # check if the output json files exist + assert os.path.exists(os.path.join(output_dir, "added_tokens.json")) + assert os.path.exists(os.path.join(output_dir, "config.json")) + assert os.path.exists(os.path.join(output_dir, "processed_chat_template.json")) + assert os.path.exists(os.path.join(output_dir, "special_tokens_map.json")) + assert os.path.exists(os.path.join(output_dir, "tokenizer.json")) + assert os.path.exists(os.path.join(output_dir, "tokenizer_config.json")) + assert os.path.exists(os.path.join(output_dir, "vocab.json")) + + # check if the output onnx file exists + assert os.path.exists(os.path.join(output_dir, "model.onnx")) + + # check if the onnx file has 24 AttentionPlugin operators + onnx_model = onnx.load(os.path.join(output_dir, "model.onnx")) + attn_ops = [node for node in onnx_model.graph.node if node.op_type == "AttentionPlugin"] + assert len(attn_ops) == num_attn_ops + + # check input and output names of the model + actual_inputs_names = {input.name for input in onnx_model.graph.input} + actual_outputs_names = {output.name for output in onnx_model.graph.output} + expect_inputs_names = { + "input_ids", + "context_lengths", + "rope_rotary_cos_sin", + "kvcache_start_index", + "last_token_ids", + } + expect_outputs_names = {"logits"} + expect_inputs_names.update({f"past_key_values_{i}" for i in range(num_attn_ops)}) + expect_outputs_names.update({f"present_key_values_{i}" for i in range(num_attn_ops)}) + assert actual_inputs_names == expect_inputs_names + assert actual_outputs_names == expect_outputs_names diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rope_attention.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rope_attention.py new file mode 100644 index 0000000000..561fe34338 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rope_attention.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for fuse_rope_attention transformation. +""" + +from collections import namedtuple + +import pytest +import torch +from torch.export import Dim + +# Import modules to register custom ops (torch.ops.auto_deploy.*) +import tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention # noqa: F401 +import tensorrt_llm._torch.auto_deploy.custom_ops.torch_rope # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer + +torch.manual_seed(0) + + +def _get_cos_sin(x: torch.Tensor, head_dim): + """Precompute cos and sin for RoPE.""" + batch_size = x.size(0) + seq_len = x.size(1) + cos = x.new_zeros(batch_size, seq_len, head_dim) + sin = x.new_zeros(batch_size, seq_len, head_dim) + return cos, sin + + +class RopeAttentionModel(torch.nn.Module): + """ + Model that implements the rope + attention pattern that fuse_rope_attention expects. + + Pattern: + 1. q_proj, k_proj, v_proj + 2. view to [batch, seq, num_heads, head_dim] + 3. contiguous + 4. torch_rope_with_explicit_cos_sin for q and k + 5. torch_attention with layout="bsnd" + """ + + def __init__( + self, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + max_seq_len: int = 128, + ): + super().__init__() + hidden_size = head_dim * num_q_heads + self.head_dim = head_dim + self.hidden_size = hidden_size + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.head_dim = hidden_size // num_q_heads + self.max_seq_len = max_seq_len + + # Linear projections + self.q_proj = torch.nn.Linear(hidden_size, num_q_heads * self.head_dim, bias=False) + self.k_proj = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False) + self.v_proj = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False) + self.o_proj = torch.nn.Linear(num_q_heads * self.head_dim, hidden_size, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, # Required for export API compatibility + ) -> torch.Tensor: + x = input_ids.unsqueeze(-1).expand(-1, -1, self.hidden_size).to(torch.float16) + batch_size, seq_len, _ = x.shape + + # Generate q, k, v: [batch, seq, hidden_size] -> [batch, seq, num_heads * head_dim] + q = self.q_proj(x) # [batch, seq, num_q_heads * head_dim] + k = self.k_proj(x) # [batch, seq, num_kv_heads * head_dim] + v = self.v_proj(x) # [batch, seq, num_kv_heads * head_dim] + + # View to [batch, seq, num_heads, head_dim] + q = q.view(batch_size, seq_len, self.num_q_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Apply contiguous (this is part of the pattern) + q = torch.ops.aten.contiguous(q) + k = torch.ops.aten.contiguous(k) + + # Precompute cos and sin for RoPE + cos, sin = _get_cos_sin(x, self.head_dim) + + # Apply RoPE with unsqueeze_dim=2 for BSND layout + # NOTE(yoco): This should be (q, k) according to the definition of torch_rope_with_explicit_cos_sin. + # However, the actual graph is (k, q), Need further investigation. 
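+        # The swapped (k, q) ordering is intentionally mirrored in the torch_attention call
+        # below, so the exported graph matches the argument layout the fuse_rope_attention
+        # matcher expects.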
+        q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(k, q, cos, sin, 2)
+
+        # Apply attention with layout="bsnd"
+        attn_output = torch.ops.auto_deploy.torch_attention(
+            k_rot,
+            q_rot,
+            v,
+            None,  # attn_mask
+            0.0,  # dropout_p
+            True,  # is_causal
+            None,  # scale
+            None,  # sinks
+            None,  # sliding_window
+            None,  # logit_cap
+            "bsnd",  # layout
+        )
+
+        # Reshape back and apply output projection
+        attn_output = attn_output.reshape(batch_size, seq_len, -1)
+        output = self.o_proj(attn_output)
+        return output
+
+    def _apply_rope_manual(self, q, k, cos, sin):
+        """Manual rope application for fallback/comparison."""
+
+        def rotate_half(x):
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            return torch.cat((-x2, x1), dim=-1)
+
+        cos = cos.unsqueeze(2)  # unsqueeze_dim=2 for BSND
+        sin = sin.unsqueeze(2)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    def get_dynamic_shapes(self):
+        return [
+            {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)},
+            {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)},
+        ]
+
+
+def convert_placeholder_meta_to_cuda(gm: torch.fx.GraphModule):
+    """Convert placeholder meta to cuda."""
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            node.meta["val"] = node.meta["val"].to("cuda")
+
+
+def _run_test(
+    head_dim: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    batch_size: int,
+    seq_len: int,
+):
+    """Helper function to run the transformation test."""
+    model = RopeAttentionModel(
+        head_dim=head_dim,
+        num_q_heads=num_q_heads,
+        num_kv_heads=num_kv_heads,
+    ).to("cuda", torch.float16)
+    input_ids = torch.randint(0, 10, (batch_size, seq_len), device="cuda", dtype=torch.int32)
+    position_ids = torch.randint(0, 10, (batch_size, seq_len), device="cuda", dtype=torch.int32)
+    dynamic_shapes = model.get_dynamic_shapes()
+
+    # Export to graph module
+    args = (input_ids, position_ids)
+    gm = torch_export_to_gm(model, args=args, dynamic_shapes=dynamic_shapes, clone=True)
+
+    # Model config object exposing num_attention_heads, hidden_size, and num_key_value_heads
+    class Factory:
+        def _get_model_config(self):
+            ModelConfig = namedtuple(
+                "ModelConfig", ["num_attention_heads", "hidden_size", "num_key_value_heads"]
+            )
+            return ModelConfig(
+                num_attention_heads=num_q_heads,
+                hidden_size=head_dim * num_q_heads,
+                num_key_value_heads=num_kv_heads,
+            ), None
+
+    # Apply fuse_rope_attention transformation
+
+    optimizer = InferenceOptimizer(
+        Factory(),
+        {
+            "fuse_rope_attention": {
+                "stage": "pattern_matcher",
+            },
+        },
+    )
+    convert_placeholder_meta_to_cuda(gm)
+    gm_transformed = optimizer(None, gm)
+    gm_transformed.to("cuda")
+
+    fused_nodes = gm_transformed.graph.find_nodes(
+        op="call_function", target=torch.ops.auto_deploy.torch_onnx_attention_plugin.default
+    )
+    assert len(fused_nodes) == 1, f"Expected 1 AttentionPlugin node, got {len(fused_nodes)}"
+    input_nodes = gm_transformed.graph.find_nodes(op="placeholder")
+    assert len(input_nodes) == 5, f"Expected 5 input nodes, got {len(input_nodes)}"
+    input_nodes_targets = {node.target for node in input_nodes}
+    assert input_nodes_targets == {
+        "input_ids",
+        "context_lengths",
+        "rope_rotary_cos_sin",
+        "kvcache_start_index",
+        "past_key_values_0",
+    }
+    out_node = gm_transformed.graph.find_nodes(op="output")[0]
+    assert len(out_node.args[0]) == 2, f"Expected 2 output nodes, got {len(out_node.args[0])}"
+
+
+@pytest.mark.parametrize(
+    "head_dim,num_q_heads,num_kv_heads,batch_size,seq_len",
+    [
+        # GQA
(Grouped-Query Attention): num_q_heads > num_kv_heads + pytest.param(64, 14, 2, 2, 16, id="gqa_ratio_7"), + ], +) +@torch.inference_mode() +def test_fuse_rope_attention( + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + batch_size: int, + seq_len: int, +): + """Test fuse_rope_attention transformation with various configurations.""" + _run_test( + head_dim=head_dim, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + batch_size=batch_size, + seq_len=seq_len, + )