Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-17 00:04:57 +08:00)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>

parent 639051e98b
commit d160439ef9

The changes below add SPDX Apache-2.0 license headers across the auto_deploy custom-ops files and move those files into per-domain subpackages (attention/, distributed/, linear/, mla/, normalization/, quantization/, rope/, utils/), with relative and absolute import paths updated to match.
@ -1,3 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Custom ops and make sure they are all registered."""

 import importlib
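For orientation (not part of the diff): this package __init__.py exists so that importing the package triggers each submodule's torch.library registrations. A minimal sketch of that pattern, with an illustrative submodule list:

import importlib

# Importing each submodule runs its @torch.library.custom_op registrations
# as a side effect, populating the torch.ops.auto_deploy namespace.
for _name in ("attention", "normalization", "rope"):  # illustrative names
    importlib.import_module(f".{_name}", package=__name__)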
@ -0,0 +1,36 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Attention operations.
+
+This module provides various attention implementations and backends:
+- torch_attention: PyTorch reference implementations
+- torch_backend_attention: PyTorch-based attention backend
+- flashinfer_attention: FlashInfer-based optimized attention
+- triton_attention: Triton-based attention implementations
+- triton_attention_with_kv_cache: Triton attention with KV cache support
+- triton_attention_with_paged_kv_cache: Triton attention with paged KV cache
+- onnx_attention: Placeholder ops for ONNX export of attention mechanisms
+"""
+
+__all__ = [
+    "torch_attention",
+    "torch_backend_attention",
+    "flashinfer_attention",
+    "triton_attention",
+    "triton_attention_with_kv_cache",
+    "triton_attention_with_paged_kv_cache",
+    "onnx_attention",
+]
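Side note: once these submodules are imported, their kernels are callable through the torch.ops.auto_deploy namespace. A minimal sketch (the op queried here, flashinfer_rms_norm, appears later in this diff; the check only succeeds if the defining module actually imported):

import torch
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401  # triggers op registration

# hasattr on a torch.ops namespace returns False rather than raising
# when the op was never registered.
print(hasattr(torch.ops.auto_deploy, "flashinfer_rms_norm"))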
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 from dataclasses import dataclass, fields
 from typing import Dict, List, Literal, Optional, Tuple, Union

@ -7,12 +22,12 @@ from torch._ops import OpOverloadPacket
 from torch._subclasses import FakeTensor
 from torch.fx import Node

-from ....llmapi.llm_args import KvCacheConfig
-from ...flashinfer_utils import get_env_enable_pdl
-from ..utils.cuda_graph import cuda_graph_state
-from ..utils.logger import ad_logger
-from ..utils.node_utils import extract_op_args
-from .attention_interface import (
+from .....llmapi.llm_args import KvCacheConfig
+from ....flashinfer_utils import get_env_enable_pdl
+from ...utils.cuda_graph import cuda_graph_state
+from ...utils.logger import ad_logger
+from ...utils.node_utils import extract_op_args
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,
@ -1,10 +1,11 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Torch reference implementations for attention."""

 import math
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Torch backend attention using pure PyTorch reference implementations."""

 import math

@ -8,10 +23,10 @@ from torch._ops import OpOverloadPacket
 from torch._subclasses import FakeTensor
 from torch.fx import Node

-from ....llmapi.llm_args import KvCacheConfig
-from ..utils.logger import ad_logger
-from ..utils.node_utils import extract_op_args
-from .attention_interface import (
+from .....llmapi.llm_args import KvCacheConfig
+from ...utils.logger import ad_logger
+from ...utils.node_utils import extract_op_args
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom ops for MHA/XQA attention."""

 import math

@ -9,10 +24,10 @@ from torch._ops import OpOverloadPacket
 from torch._subclasses import FakeTensor
 from torch.fx import Node

-from ....llmapi.llm_args import KvCacheConfig
-from ..utils.logger import ad_logger
-from ..utils.node_utils import extract_op_args
-from .attention_interface import (
+from .....llmapi.llm_args import KvCacheConfig
+from ...utils.logger import ad_logger
+from ...utils.node_utils import extract_op_args
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,

@ -21,7 +36,7 @@ from .attention_interface import (
     ResourceHandlerDict,
     UnpagedResourceHandler,
 )
-from .triton_kernels.attention_with_kv_cache import (
+from .triton_attention_with_kv_cache import (
     attention_kv_stage2,
     context_attention_kv_flattened,
     gqa_attention_kv_stage1,
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Multi-head attention kernel that can operate with kv-caches."""

 import triton
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Attention Interface to handle various attention operators and cache operations.

 This module provides an interface between the high-level runtime and cache management system and
@ -0,0 +1,26 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Distributed operations.
+
+This module provides distributed communication primitives:
+- torch_dist: PyTorch distributed backend operations
+- trtllm_dist: TensorRT-LLM optimized distributed operations
+"""
+
+__all__ = [
+    "torch_dist",
+    "trtllm_dist",
+]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom ops required for implementing tensor parallelism.

 This module defines atomic distributed ops - each op uses a specific backend

@ -8,7 +23,7 @@ from typing import List, Optional

 import torch

-from ..distributed import common as dist
+from ...distributed import common as dist

 # ============================================================================
 # PyTorch Distributed Backend Ops (demollm mode)
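To make the "atomic op per backend" idea concrete, here is a generic sketch (namespace and names are hypothetical, not the file's actual ops): a PyTorch-distributed all-reduce wrapped as a torch.library custom op, with a fake kernel so the op can be traced and exported.

import torch
import torch.distributed as dist

@torch.library.custom_op("example_dist::all_reduce_sum", mutates_args=())
def all_reduce_sum(x: torch.Tensor) -> torch.Tensor:
    out = x.clone()
    dist.all_reduce(out, op=dist.ReduceOp.SUM)  # the PyTorch backend does the work
    return out

@all_reduce_sum.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)  # shape/dtype only; enough for tracing/export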
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """TRT-LLM distributed operations and fused kernels.

 This module defines atomic TRT-LLM-specific ops that use optimized kernels.

@ -9,10 +24,10 @@ from typing import List, Optional

 import torch

 # use trtllm distributed ops to improve TP performance if possible
-from ....mapping import Mapping
-from ...distributed import AllReduce, allgather
-from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
-from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi
+from .....mapping import Mapping
+from ....distributed import AllReduce, allgather
+from ....modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
+from ...distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi

 # Cache AllReduce modules to avoid recreating on every call
 # This is critical for CUDA graph compatibility - recreating modules during
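The cache comment above is the key constraint: a module captured in a CUDA graph must not be reallocated between replays. A generic sketch of the pattern (the cache key and factory are assumptions, not the file's actual code):

# Assumption: constructing an AllReduce module allocates workspaces whose
# pointers get baked into a captured CUDA graph, so instances must be reused.
_ALLREDUCE_CACHE: dict = {}

def _get_allreduce(key, factory):
    # key: hashable config (e.g. tp group + strategy); factory: zero-arg constructor.
    if key not in _ALLREDUCE_CACHE:
        _ALLREDUCE_CACHE[key] = factory()
    return _ALLREDUCE_CACHE[key]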
@ -0,0 +1,14 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]

@ -0,0 +1,14 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Cached attention op for delta rule using the fla kernel library.

 Delta Rule is based on this paper: https://arxiv.org/abs/2406.06484
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom ops corresponding to fla's chunked delta rule.

 Delta Rule is based on this paper: https://arxiv.org/abs/2406.06484
@ -0,0 +1,14 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """
 AOT-compiled moe_align CUDA kernel.
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 # Triton-kernels-based MXFP4 MoE ops (GPT-OSS style) with routing, swizzling, and fused activation

 from typing import Callable, Tuple
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 from typing import Callable, List

 import torch
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """
 Triton implementation of the Fused MOE ops. Inspired by vLLM's triton MOE implementation.
 """

@ -15,7 +15,9 @@

 import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.quant import TRTLLM_NVFP4_SCALING_VECTOR_SIZE
+from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import (
+    TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
+)
 from tensorrt_llm._torch.utils import ActivationType
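For background on what the fused-MoE ops compute before the Triton fusion, a plain-PyTorch top-k routing sketch (illustrative only; the real kernels fuse routing with the expert GEMMs):

import torch

def topk_route(router_logits: torch.Tensor, k: int):
    # Pick k experts per token and renormalize their gate weights.
    probs = torch.softmax(router_logits, dim=-1)
    weights, expert_ids = torch.topk(probs, k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_ids

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
w, ids = topk_route(logits, k=2)
print(ids.shape, w.shape)  # torch.Size([4, 2]) torch.Size([4, 2])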
@ -0,0 +1,26 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Linear operations.
+
+This module provides linear layer implementations:
+- linear: Linear layer operations
+- torch_router: MoE router operations
+"""
+
+__all__ = [
+    "linear",
+    "torch_router",
+]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom ops for linear layers."""

 from typing import Optional
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 import torch
 import torch.nn.functional as F
@ -0,0 +1,14 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 from typing import List

 import torch
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom op collection for cached causal conv1d in pure PyTorch.

 This mirrors the structure used by the cached Mamba/SSM ops:
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom op collection for cached mamba2 ssm transform (linear attention) in pure PyTorch.

 This file contains two kinds of functionality:
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom op collection for uncached causal conv (sliding window with 1d)."""

 from typing import Optional
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom op collection for uncached mamba mixer (linear attention)."""

 from typing import List, Tuple
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py (new file, 24 lines)
@ -0,0 +1,24 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Multi-head Latent Attention operations.
+
+This module provides Multi-head Latent Attention (MLA) implementations:
+- mla: MLA operations and attention descriptor
+"""
+
+__all__ = [
+    "mla",
+]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom ops for MultiHead Latent attention."""

 import math

@ -7,8 +22,9 @@ import torch
 from torch._ops import OpOverloadPacket
 from torch.fx import Node

-from ....llmapi.llm_args import KvCacheConfig
-from .attention_interface import (
+from .....llmapi.llm_args import KvCacheConfig
+from ..attention.triton_attention import _decode_attention, _prefill_attention
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,

@ -16,7 +32,6 @@ from .attention_interface import (
     ResourceHandlerDict,
     UnpagedResourceHandler,
 )
-from .triton_attention import _decode_attention, _prefill_attention

 Constant = Union[int, float, str, None]
@ -0,0 +1,30 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Normalization operations.
+
+This module provides various normalization implementations:
+- rms_norm: RMSNorm implementations (FlashInfer, Triton, reference)
+- triton_rms_norm: Low-level Triton RMSNorm kernel
+- l2norm: L2 normalization operations
+- flashinfer_fused_add_rms_norm: Fused add + RMSNorm operation
+"""
+
+__all__ = [
+    "rms_norm",
+    "triton_rms_norm",
+    "l2norm",
+    "flashinfer_fused_add_rms_norm",
+]
@ -12,7 +12,7 @@
 import flashinfer
 import torch

-from ...flashinfer_utils import get_env_enable_pdl
+from ....flashinfer_utils import get_env_enable_pdl


 @torch.library.custom_op(
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom ops corresponding to l2norm."""

 import torch
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Custom operator for FlashInfer and Triton RMSNorm implementation."""

 import flashinfer

@ -6,9 +21,9 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange

-from ...flashinfer_utils import get_env_enable_pdl
-from ...modules.mamba.layernorm_gated import _layer_norm_fwd
-from .triton_kernels.rms_norm import rms_norm
+from ....flashinfer_utils import get_env_enable_pdl
+from ....modules.mamba.layernorm_gated import _layer_norm_fwd
+from .triton_rms_norm import rms_norm


 @torch.library.custom_op("auto_deploy::flashinfer_rms_norm", mutates_args=())
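The decorator on the last line above is the registration mechanism used throughout these files. A self-contained sketch of the same pattern with a pure-PyTorch RMSNorm body (namespace and op name here are placeholders, not the diff's op):

import torch

@torch.library.custom_op("example_norm::rms_norm_ref", mutates_args=())
def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Classic RMSNorm: scale by the reciprocal RMS, then apply the learned weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

@rms_norm_ref.register_fake
def _(x, weight, eps):
    return torch.empty_like(x)

out = torch.ops.example_norm.rms_norm_ref(torch.randn(2, 16), torch.ones(16), 1e-6)
print(out.shape)  # torch.Size([2, 16])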
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 import torch
 import triton
 import triton.language as tl
@ -0,0 +1,26 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Quantization operations.
+
+This module provides quantization utilities and operations:
+- quant: Quantization operations (FP8, FP4, INT4, INT8)
+- torch_quant: PyTorch-based quantization implementations
+"""
+
+__all__ = [
+    "quant",
+    "torch_quant",
+]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Definition of the quant module that can be used for PTQ."""

 import warnings

@ -7,7 +22,7 @@ import torch
 from flashinfer import bmm_fp8
 from torch import nn

-from .torch_libs.float8_python_api import addmm_float8_unwrapped
+from ..torch_libs.float8_python_api import addmm_float8_unwrapped

 TRTLLM_FP4_OP_AVAILABLE = True
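For background on the FP8 side of this PTQ module (constants like FP8_MAX are imported from it elsewhere in this diff), a minimal per-tensor FP8 quantization sketch in plain PyTorch; this is the standard scaling recipe, not necessarily the file's exact code:

import torch

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn

def quantize_fp8_per_tensor(x: torch.Tensor):
    # Scale so the largest magnitude maps onto the FP8 range, then cast.
    scale = x.abs().amax().clamp(min=1e-12) / FP8_E4M3_MAX
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale

x = torch.randn(4, 8)
x_fp8, s = quantize_fp8_per_tensor(x)
x_back = x_fp8.to(torch.float32) * s  # dequantize
print((x - x_back).abs().max())  # small quantization error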
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 from typing import List, Optional

 import torch
tensorrt_llm/_torch/auto_deploy/custom_ops/rope/__init__.py (new file, 30 lines)
@ -0,0 +1,30 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""RoPE (Rotary Position Embedding) operations.
+
+This module provides various RoPE implementations:
+- torch_rope: PyTorch reference implementation
+- flashinfer_rope: FlashInfer-based optimized RoPE
+- triton_rope: Triton-based RoPE implementation
+- triton_rope_kernel: Low-level Triton kernels for RoPE
+"""
+
+__all__ = [
+    "torch_rope",
+    "flashinfer_rope",
+    "triton_rope",
+    "triton_rope_kernel",
+]
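As a reminder of what these RoPE backends compute, a compact reference in plain PyTorch (interleaved-pair convention; the repo's own reference implementation may use a different layout):

import torch

def apply_rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [..., seq, dim] with even dim; pos: [seq] integer positions.
    dim = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = pos[:, None].float() * inv_freq[None, :]  # [seq, dim/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(1, 8, 64)
print(apply_rope(q, torch.arange(8)).shape)  # torch.Size([1, 8, 64])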
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 from typing import Tuple

 import flashinfer
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 from typing import Tuple

 import torch
@ -1,7 +1,22 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 import torch
 import triton

-from .triton_kernels.rope import rope_fwd_flattened_kernel, rope_fwd_kernel
+from .triton_rope_kernel import rope_fwd_flattened_kernel, rope_fwd_kernel


 @torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=())
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 import triton
 import triton.language as tl
@ -0,0 +1,14 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
tensorrt_llm/_torch/auto_deploy/custom_ops/utils/__init__.py (new file, 26 lines)
@ -0,0 +1,26 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
+"""Utility operations.
+
+This module provides utility functions and helpers:
+- torch_gather_logits: Logit gathering operations
+- triton_utils: Triton utility functions and helpers
+"""
+
+__all__ = [
+    "torch_gather_logits",
+    "triton_utils",
+]
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 import torch
@ -1,3 +1,18 @@
+# [SPDX Apache-2.0 header added, identical to the first hunk]
+
 """Triton utility operations for auto_deploy."""

 import torch
@ -31,7 +31,7 @@ from transformers.generation import GenerationMixin
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput

-from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import gated_rms_norm_ref
 from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
 from tensorrt_llm._torch.utils import ActivationType
@ -16,7 +16,7 @@ from typing import Tuple
 import torch
 from torch.fx import GraphModule

-from ...custom_ops.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
+from ...custom_ops.normalization.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
@ -9,7 +9,10 @@ from torch.fx import GraphModule, Node

 from tensorrt_llm._torch.utils import ActivationType

-from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
+from ...custom_ops.quantization.quant import (
+    TRTLLM_NVFP4_PACKING_FACTOR,
+    TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
+)
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name
@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 from torch.fx import GraphModule, Node

-from ...custom_ops.quant import (
+from ...custom_ops.quantization.quant import (
     FP4_GLOBAL_SCALE_MAX,
     FP8_MAX,
     TRTLLM_NVFP4_COLUMN_SIZE,
@ -6,7 +6,7 @@ import torch
 from pydantic import Field
 from torch.fx import GraphModule, Node

-from ...custom_ops.rms_norm import gated_rms_norm_ref
+from ...custom_ops.normalization.rms_norm import gated_rms_norm_ref
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
@ -30,7 +30,7 @@ from pydantic import BaseModel, Field, field_validator
 from torch.fx import GraphModule, Node

 from .....functional import AllReduceStrategy
-from ...custom_ops.trtllm_dist import is_trtllm_op_available
+from ...custom_ops.distributed.trtllm_dist import is_trtllm_op_available
 from ...models.factory import ModelFactory, ShardingConfigSource
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import del_attr_by_name, eliminate_dead_code
@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 from torch.fx import GraphModule, Node

-from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX
+from ..custom_ops.quantization.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX
 from .logger import ad_logger
 from .node_utils import (
     extract_weight_name,
@ -5,7 +5,9 @@ import torch
 from _dist_test_utils import get_device_counts
 from torch.export import export

-from tensorrt_llm._torch.auto_deploy.custom_ops.trtllm_dist import is_trtllm_op_available
+from tensorrt_llm._torch.auto_deploy.custom_ops.distributed.trtllm_dist import (
+    is_trtllm_op_available,
+)
 from tensorrt_llm._torch.auto_deploy.distributed.common import initialize_or_skip
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
@ -22,7 +22,7 @@ import torch
 import torch.nn as nn

 # Ensure custom ops are registered
-from tensorrt_llm._torch.auto_deploy.custom_ops import rms_norm  # noqa: F401
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization import rms_norm  # noqa: F401
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -3,7 +3,9 @@ import pytest
 import torch
 from torch_attention_reference import TorchAttentionReference

-from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import _GlobalFlashInferPlanner
+from tensorrt_llm._torch.auto_deploy.custom_ops.attention.flashinfer_attention import (
+    _GlobalFlashInferPlanner,
+)


 def _create_combined_kv_cache(k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:
@ -7,7 +7,7 @@ import triton
 from _custom_op_utils import torch_rope_reference
 from _model_test_utils import repeat_kv

-from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.attention_with_kv_cache import (
+from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_attention_with_kv_cache import (
     attention_kv_stage1,
     attention_kv_stage2,
     context_attention_kv,
@ -2,7 +2,7 @@ import pytest
 import torch

 import tensorrt_llm._torch.auto_deploy  # noqa: F401
-from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.test_triton_mamba_cached_op import (
+from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.mamba.test_triton_mamba_cached_op import (
     _random_params,
 )
@ -13,7 +13,7 @@ from torch.nn import functional as F
 from utils.util import skip_pre_hopper

 import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
-from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
+from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import (
     TRTLLM_NVFP4_COLUMN_SIZE,
     TRTLLM_NVFP4_ROW_SIZE,
     TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
@ -1,7 +1,7 @@
 import pytest
 import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import (
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import (
     flashinfer_fused_add_rms_norm,
 )
@ -1,7 +1,7 @@
 import pytest
 import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import (
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import (
     gated_rms_norm_ref,
     triton_rmsnorm_gated,
 )
@ -1,7 +1,7 @@
 import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
-from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.rms_norm import rms_norm
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import *  # noqa
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.triton_rms_norm import rms_norm


 def test_rmsnorm_triton_op():
@ -1,6 +1,6 @@
 import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention import update_kv_cache
+from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention import update_kv_cache


 def test_update_kv_cache():
@ -4,7 +4,7 @@ import pytest
 import torch

 # Import to register the custom op
-from tensorrt_llm._torch.auto_deploy.custom_ops import triton_utils  # noqa: F401
+from tensorrt_llm._torch.auto_deploy.custom_ops.utils import triton_utils  # noqa: F401


 def _reference_gather_scatter(
@ -3,7 +3,7 @@ import torch
 from _graph_test_helpers import run_test_transformed_gm
 from torch.export import Dim

-from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import *  # noqa
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.l2norm import *  # noqa
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -3,7 +3,7 @@ import torch
 from _graph_test_helpers import run_test_transformed_gm
 from torch.export import Dim

-from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import *  # noqa
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -23,8 +23,8 @@ import torch
 from torch.export import Dim

 # Import modules to register custom ops (torch.ops.auto_deploy.*)
-import tensorrt_llm._torch.auto_deploy.custom_ops.torch_attention  # noqa: F401
-import tensorrt_llm._torch.auto_deploy.custom_ops.torch_rope  # noqa: F401
+import tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention  # noqa: F401
+import tensorrt_llm._torch.auto_deploy.custom_ops.rope.torch_rope  # noqa: F401
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
@ -1,8 +1,8 @@
 import torch
 from torch.export import Dim

-from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import *  # noqa
-from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import *  # noqa
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import *  # noqa
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -1,7 +1,7 @@
 import pytest
 import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8_MAX
+from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import FP8_MAX
 from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
 from tensorrt_llm._torch.auto_deploy.transform.library.quantization import (
     FP8LinearQuantizationFromConfig,