[#11148][feat] AutoDeploy: Better structure the custom op (#11152)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Author: Chenghao Zhang, 2026-02-05 21:32:22 -08:00 (committed by GitHub)
Parent: 639051e98b
Commit: d160439ef9
85 changed files with 872 additions and 58 deletions

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops and make sure they are all registered."""
import importlib
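For context, a minimal sketch (illustrative only, not the actual __init__ body) of how a package like this can use importlib to pull in every submodule so that their torch custom ops are registered as an import side effect:

# Illustrative sketch: iterate over the package's submodules and import each one,
# which triggers their @torch.library.custom_op registrations under torch.ops.auto_deploy.
import importlib
import pkgutil

for _module_info in pkgutil.iter_modules(__path__):
    importlib.import_module(f"{__name__}.{_module_info.name}")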

View File

@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention operations.
This module provides various attention implementations and backends:
- torch_attention: PyTorch reference implementations
- torch_backend_attention: PyTorch-based attention backend
- flashinfer_attention: FlashInfer-based optimized attention
- triton_attention: Triton-based attention implementations
- triton_attention_with_kv_cache: Triton attention with KV cache support
- triton_attention_with_paged_kv_cache: Triton attention with paged KV cache
- onnx_attention: Placeholder ops for ONNX export of attention mechanisms
"""
__all__ = [
"torch_attention",
"torch_backend_attention",
"flashinfer_attention",
"triton_attention",
"triton_attention_with_kv_cache",
"triton_attention_with_paged_kv_cache",
"onnx_attention",
]
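As a usage sketch of the new layout, attention ops now live under this attention subpackage; the paths below are taken from import changes elsewhere in this commit, and importing a module registers its custom ops as a side effect:

# Importing registers the module's ops under torch.ops.auto_deploy (exact op names depend on the backend).
from tensorrt_llm._torch.auto_deploy.custom_ops.attention import triton_attention  # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention import update_kv_cache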

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, fields
from typing import Dict, List, Literal, Optional, Tuple, Union
@ -7,12 +22,12 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from ...flashinfer_utils import get_env_enable_pdl
from ..utils.cuda_graph import cuda_graph_state
from ..utils.logger import ad_logger
from ..utils.node_utils import extract_op_args
from .attention_interface import (
from .....llmapi.llm_args import KvCacheConfig
from ....flashinfer_utils import get_env_enable_pdl
from ...utils.cuda_graph import cuda_graph_state
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,

View File

@ -1,10 +1,11 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch reference implementations for attention."""
import math

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch backend attention using pure PyTorch reference implementations."""
import math
@ -8,10 +23,10 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from ..utils.logger import ad_logger
from ..utils.node_utils import extract_op_args
from .attention_interface import (
from .....llmapi.llm_args import KvCacheConfig
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops for MHA/XQA attention."""
import math
@ -9,10 +24,10 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from ..utils.logger import ad_logger
from ..utils.node_utils import extract_op_args
from .attention_interface import (
from .....llmapi.llm_args import KvCacheConfig
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
@ -21,7 +36,7 @@ from .attention_interface import (
ResourceHandlerDict,
UnpagedResourceHandler,
)
from .triton_kernels.attention_with_kv_cache import (
from .triton_attention_with_kv_cache import (
attention_kv_stage2,
context_attention_kv_flattened,
gqa_attention_kv_stage1,

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-head attention kernel that can operate with kv-caches."""
import triton

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention Interface to handle various attention operators and cache operations.
This module provides an interface between the high-level runtime and cache management system and

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distributed operations.
This module provides distributed communication primitives:
- torch_dist: PyTorch distributed backend operations
- trtllm_dist: TensorRT-LLM optimized distributed operations
"""
__all__ = [
"torch_dist",
"trtllm_dist",
]
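Callers that previously imported trtllm_dist directly from custom_ops now go through this subpackage, as the test updates later in this commit show:

from tensorrt_llm._torch.auto_deploy.custom_ops.distributed.trtllm_dist import (
    is_trtllm_op_available,  # used by transforms to choose between TRT-LLM and plain torch_dist ops
)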

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops required for implementing tensor parallelism.
This module defines atomic distributed ops - each op uses a specific backend
@ -8,7 +23,7 @@ from typing import List, Optional
import torch
from ..distributed import common as dist
from ...distributed import common as dist
# ============================================================================
# PyTorch Distributed Backend Ops (demollm mode)

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TRT-LLM distributed operations and fused kernels.
This module defines atomic TRT-LLM-specific ops that use optimized kernels.
@ -9,10 +24,10 @@ from typing import List, Optional
import torch
# use trtllm distributed ops to improve TP performance if possible
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi
from .....mapping import Mapping
from ....distributed import AllReduce, allgather
from ....modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
from ...distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi
# Cache AllReduce modules to avoid recreating on every call
# This is critical for CUDA graph compatibility - recreating modules during
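The comment above describes reusing AllReduce modules across calls; a minimal sketch of that caching idea, with a hypothetical helper name and simplified cache key (the constructor arguments are assumed, not taken from this file):

from tensorrt_llm._torch.distributed import AllReduce
from tensorrt_llm.mapping import Mapping

_ALLREDUCE_CACHE = {}  # hypothetical module-level cache

def _get_allreduce(mapping: Mapping, strategy) -> AllReduce:
    # Reusing one module instance keeps its internal buffers stable, which CUDA graph capture relies on.
    key = (mapping.tp_size, strategy)
    if key not in _ALLREDUCE_CACHE:
        _ALLREDUCE_CACHE[key] = AllReduce(mapping=mapping, strategy=strategy)
    return _ALLREDUCE_CACHE[key]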

View File

@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cached attention op for delta rule using the fla kernel library.
Delta Rule is based on this paper: https://arxiv.org/abs/2406.06484

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops corresponding to fla's chunked delta rule.
Delta Rule is based on this paper: https://arxiv.org/abs/2406.06484

View File

@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AOT-compiled moe_align CUDA kernel.

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Triton-kernels-based MXFP4 MoE ops (GPT-OSS style) with routing, swizzling, and fused activation
from typing import Callable, Tuple

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List
import torch

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Triton implementation of the Fused MOE ops. Inspired by vLLM's triton MOE implementation.
"""

View File

@ -15,7 +15,9 @@
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import (
TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
)
from tensorrt_llm._torch.utils import ActivationType

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear operations.
This module provides linear layer implementations:
- linear: Linear layer operations
- torch_router: MoE router operations
"""
__all__ = [
"linear",
"torch_router",
]
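A hypothetical usage sketch; the dotted path below is inferred from the layout this commit introduces and does not appear verbatim in the diff:

from tensorrt_llm._torch.auto_deploy.custom_ops.linear import torch_router  # noqa: F401  # assumed path; registers the MoE router ops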

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops for linear layers."""
from typing import Optional

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F

View File

@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom op collection for cached causal conv1d in pure PyTorch.
This mirrors the structure used by the cached Mamba/SSM ops:

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom op collection for cached mamba2 ssm transform (linear attention) in pure PyTorch.
This file contains two kinds of functionality:

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom op collection for uncached causal conv (sliding window with 1d)."""
from typing import Optional

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom op collection for uncached mamba mixer (linear attention)."""
from typing import List, Tuple

View File

@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-head Latent Attention operations.
This module provides Multi-head Latent Attention (MLA) implementations:
- mla: MLA operations and attention descriptor
"""
__all__ = [
"mla",
]
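A usage sketch; the package path is inferred from the relative imports in the mla module below (five dots up to tensorrt_llm, two dots over to custom_ops.attention):

from tensorrt_llm._torch.auto_deploy.custom_ops.mla import mla  # noqa: F401  # inferred path; registers the MLA custom ops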

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops for MultiHead Latent attention."""
import math
@ -7,8 +22,9 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from .attention_interface import (
from .....llmapi.llm_args import KvCacheConfig
from ..attention.triton_attention import _decode_attention, _prefill_attention
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
@ -16,7 +32,6 @@ from .attention_interface import (
ResourceHandlerDict,
UnpagedResourceHandler,
)
from .triton_attention import _decode_attention, _prefill_attention
Constant = Union[int, float, str, None]

View File

@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization operations.
This module provides various normalization implementations:
- rms_norm: RMSNorm implementations (FlashInfer, Triton, reference)
- triton_rms_norm: Low-level Triton RMSNorm kernel
- l2norm: L2 normalization operations
- flashinfer_fused_add_rms_norm: Fused add + RMSNorm operation
"""
__all__ = [
"rms_norm",
"triton_rms_norm",
"l2norm",
"flashinfer_fused_add_rms_norm",
]
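The relocated RMSNorm helpers are imported from this subpackage elsewhere in the commit, for example:

from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import (
    gated_rms_norm_ref,
    triton_rmsnorm_gated,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.triton_rms_norm import rms_norm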

View File

@ -12,7 +12,7 @@
import flashinfer
import torch
from ...flashinfer_utils import get_env_enable_pdl
from ....flashinfer_utils import get_env_enable_pdl
@torch.library.custom_op(

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops corresponding to l2norm."""
import torch

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom operator for FlashInfer and Triton RMSNorm implementation."""
import flashinfer
@ -6,9 +21,9 @@ import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from ...flashinfer_utils import get_env_enable_pdl
from ...modules.mamba.layernorm_gated import _layer_norm_fwd
from .triton_kernels.rms_norm import rms_norm
from ....flashinfer_utils import get_env_enable_pdl
from ....modules.mamba.layernorm_gated import _layer_norm_fwd
from .triton_rms_norm import rms_norm
@torch.library.custom_op("auto_deploy::flashinfer_rms_norm", mutates_args=())

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import triton.language as tl

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantization operations.
This module provides quantization utilities and operations:
- quant: Quantization operations (FP8, FP4, INT4, INT8)
- torch_quant: PyTorch-based quantization implementations
"""
__all__ = [
"quant",
"torch_quant",
]
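Quantization constants move with the module; callers now import them from the quantization subpackage, as the updated tests in this commit do:

from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import (
    FP8_MAX,
    TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
)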

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of the quant module that can be used for PTQ."""
import warnings
@ -7,7 +22,7 @@ import torch
from flashinfer import bmm_fp8
from torch import nn
from .torch_libs.float8_python_api import addmm_float8_unwrapped
from ..torch_libs.float8_python_api import addmm_float8_unwrapped
TRTLLM_FP4_OP_AVAILABLE = True

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch

View File

@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RoPE (Rotary Position Embedding) operations.
This module provides various RoPE implementations:
- torch_rope: PyTorch reference implementation
- flashinfer_rope: FlashInfer-based optimized RoPE
- triton_rope: Triton-based RoPE implementation
- triton_rope_kernel: Low-level Triton kernels for RoPE
"""
__all__ = [
"torch_rope",
"flashinfer_rope",
"triton_rope",
"triton_rope_kernel",
]
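RoPE ops register on import from their new home; the torch_rope path is taken from the updated tests, and the triton_rope path follows the same layout:

import tensorrt_llm._torch.auto_deploy.custom_ops.rope.torch_rope  # noqa: F401
import tensorrt_llm._torch.auto_deploy.custom_ops.rope.triton_rope  # noqa: F401  # registers auto_deploy::triton_rope_with_input_pos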

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import flashinfer

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch

View File

@ -1,7 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
from .triton_kernels.rope import rope_fwd_flattened_kernel, rope_fwd_kernel
from .triton_rope_kernel import rope_fwd_flattened_kernel, rope_fwd_kernel
@torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=())

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl

View File

@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility operations.
This module provides utility functions and helpers:
- torch_gather_logits: Logit gathering operations
- triton_utils: Triton utility functions and helpers
"""
__all__ = [
"torch_gather_logits",
"triton_utils",
]
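As in the updated unit test, importing triton_utils from this subpackage registers its helper ops:

from tensorrt_llm._torch.auto_deploy.custom_ops.utils import triton_utils  # noqa: F401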

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

View File

@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Triton utility operations for auto_deploy."""
import torch

View File

@ -31,7 +31,7 @@ from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import gated_rms_norm_ref
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.utils import ActivationType

View File

@ -16,7 +16,7 @@ from typing import Tuple
import torch
from torch.fx import GraphModule
from ...custom_ops.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
from ...custom_ops.normalization.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern

View File

@ -9,7 +9,10 @@ from torch.fx import GraphModule, Node
from tensorrt_llm._torch.utils import ActivationType
from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from ...custom_ops.quantization.quant import (
TRTLLM_NVFP4_PACKING_FACTOR,
TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
)
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from torch.fx import GraphModule, Node
from ...custom_ops.quant import (
from ...custom_ops.quantization.quant import (
FP4_GLOBAL_SCALE_MAX,
FP8_MAX,
TRTLLM_NVFP4_COLUMN_SIZE,

View File

@ -6,7 +6,7 @@ import torch
from pydantic import Field
from torch.fx import GraphModule, Node
from ...custom_ops.rms_norm import gated_rms_norm_ref
from ...custom_ops.normalization.rms_norm import gated_rms_norm_ref
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface

View File

@ -30,7 +30,7 @@ from pydantic import BaseModel, Field, field_validator
from torch.fx import GraphModule, Node
from .....functional import AllReduceStrategy
from ...custom_ops.trtllm_dist import is_trtllm_op_available
from ...custom_ops.distributed.trtllm_dist import is_trtllm_op_available
from ...models.factory import ModelFactory, ShardingConfigSource
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import del_attr_by_name, eliminate_dead_code

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
from torch.fx import GraphModule, Node
from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX
from ..custom_ops.quantization.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX
from .logger import ad_logger
from .node_utils import (
extract_weight_name,

View File

@ -5,7 +5,9 @@ import torch
from _dist_test_utils import get_device_counts
from torch.export import export
from tensorrt_llm._torch.auto_deploy.custom_ops.trtllm_dist import is_trtllm_op_available
from tensorrt_llm._torch.auto_deploy.custom_ops.distributed.trtllm_dist import (
is_trtllm_op_available,
)
from tensorrt_llm._torch.auto_deploy.distributed.common import initialize_or_skip
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
# Ensure custom ops are registered
from tensorrt_llm._torch.auto_deploy.custom_ops import rms_norm # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization import rms_norm # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

View File

@ -3,7 +3,9 @@ import pytest
import torch
from torch_attention_reference import TorchAttentionReference
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import _GlobalFlashInferPlanner
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.flashinfer_attention import (
_GlobalFlashInferPlanner,
)
def _create_combined_kv_cache(k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:

View File

@ -7,7 +7,7 @@ import triton
from _custom_op_utils import torch_rope_reference
from _model_test_utils import repeat_kv
from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.attention_with_kv_cache import (
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_attention_with_kv_cache import (
attention_kv_stage1,
attention_kv_stage2,
context_attention_kv,

View File

@ -2,7 +2,7 @@ import pytest
import torch
import tensorrt_llm._torch.auto_deploy # noqa: F401
from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.test_triton_mamba_cached_op import (
from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.mamba.test_triton_mamba_cached_op import (
_random_params,
)

View File

@ -13,7 +13,7 @@ from torch.nn import functional as F
from utils.util import skip_pre_hopper
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import (
TRTLLM_NVFP4_COLUMN_SIZE,
TRTLLM_NVFP4_ROW_SIZE,
TRTLLM_NVFP4_SCALING_VECTOR_SIZE,

View File

@ -1,7 +1,7 @@
import pytest
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import (
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import (
flashinfer_fused_add_rms_norm,
)

View File

@ -1,7 +1,7 @@
import pytest
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import (
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import (
gated_rms_norm_ref,
triton_rmsnorm_gated,
)

View File

@ -1,7 +1,7 @@
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.rms_norm import rms_norm
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.triton_rms_norm import rms_norm
def test_rmsnorm_triton_op():

View File

@ -1,6 +1,6 @@
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention import update_kv_cache
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention import update_kv_cache
def test_update_kv_cache():

View File

@ -4,7 +4,7 @@ import pytest
import torch
# Import to register the custom op
from tensorrt_llm._torch.auto_deploy.custom_ops import triton_utils # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.utils import triton_utils # noqa: F401
def _reference_gather_scatter(

View File

@ -3,7 +3,7 @@ import torch
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.l2norm import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

View File

@ -3,7 +3,7 @@ import torch
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

View File

@ -23,8 +23,8 @@ import torch
from torch.export import Dim
# Import modules to register custom ops (torch.ops.auto_deploy.*)
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention # noqa: F401
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_rope # noqa: F401
import tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention # noqa: F401
import tensorrt_llm._torch.auto_deploy.custom_ops.rope.torch_rope # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer

View File

@ -1,8 +1,8 @@
import torch
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

View File

@ -1,7 +1,7 @@
import pytest
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8_MAX
from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import FP8_MAX
from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
from tensorrt_llm._torch.auto_deploy.transform.library.quantization import (
FP8LinearQuantizationFromConfig,