diff --git a/requirements.txt b/requirements.txt index 2ecc98bfdc..effc76c552 100644 --- a/requirements.txt +++ b/requirements.txt @@ -82,3 +82,5 @@ mistral-common==1.8.6 torchao>=0.14.1 cuda-core llist +cuda-tile>=1.0.1 +nvidia-cuda-tileiras>=13.1 diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py index dc22d08681..d5aa808f6d 100644 --- a/tensorrt_llm/_torch/compilation/utils.py +++ b/tensorrt_llm/_torch/compilation/utils.py @@ -5,6 +5,8 @@ import torch from torch.fx import Node from torch.fx.experimental.symbolic_shapes import ShapeEnv +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE + def get_symint_val(i: Union[torch.SymInt | int]): if isinstance(i, int): @@ -93,6 +95,13 @@ def inplace_info(): }, torch.ops.trtllm.pp_send_tensors.default: { 1: "tensors" - } + }, } + if IS_CUDA_TILE_AVAILABLE: + # cuda.tile availability depends on GPU capability thus runtime check. + inplace_map[ + torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_.default] = { + 1: "x", + 2: "residual" + } return inplace_map diff --git a/tensorrt_llm/_torch/cuda_tile_kernels/__init__.py b/tensorrt_llm/_torch/cuda_tile_kernels/__init__.py new file mode 100644 index 0000000000..6863d6b32a --- /dev/null +++ b/tensorrt_llm/_torch/cuda_tile_kernels/__init__.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE + +if IS_CUDA_TILE_AVAILABLE: + from .rms_norm import rms_norm_kernel, rms_norm_kernel_gather, rms_norm_kernel_static_persistent + from .rms_norm_fuse_residual import ( + rms_norm_fuse_residual_kernel, + rms_norm_fuse_residual_kernel_gather, + rms_norm_fuse_residual_kernel_static_persistent, + ) + + __all__ = [ + "rms_norm_kernel", + "rms_norm_kernel_gather", + "rms_norm_kernel_static_persistent", + "rms_norm_fuse_residual_kernel", + "rms_norm_fuse_residual_kernel_gather", + "rms_norm_fuse_residual_kernel_static_persistent", + ] diff --git a/tensorrt_llm/_torch/cuda_tile_kernels/rms_norm.py b/tensorrt_llm/_torch/cuda_tile_kernels/rms_norm.py new file mode 100644 index 0000000000..33189863d3 --- /dev/null +++ b/tensorrt_llm/_torch/cuda_tile_kernels/rms_norm.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/kernels/rms_norm.py +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE + +if IS_CUDA_TILE_AVAILABLE: + import cuda.tile as ct + + @ct.kernel + def rms_norm_kernel( + x, + w, + out, + Rstd, + N: ct.Constant[int], + eps: ct.Constant[float], + TILE_SIZE: ct.Constant[int], + use_gemma: ct.Constant[bool], + ): + """Standard RMSNorm kernel for non-static persistent mode with tiled loads""" + row = ct.bid(0) + _rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32) + num_tiles = ct.cdiv(x.shape[1], TILE_SIZE) + + for j in range(0, num_tiles): + xj = ct.load( + x, + index=(row, j), + shape=(1, TILE_SIZE), + allow_tma=False, + latency=1, + ) + xj = ct.astype(xj, ct.float32) + _rms += xj * xj + + # Calculate RMS Norm + rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps) + ct.store(Rstd, index=(row,), tile=rms) + + for j in range(0, num_tiles): + wj = ct.load( + w, + index=(j,), + shape=(TILE_SIZE,), + allow_tma=False, + latency=1, + ) + wj = ct.astype(wj, ct.float32) + # Apply Gemma-style bias if enabled + if use_gemma: + wj = wj + 1.0 + xj = ct.load( + x, + index=(row, j), + shape=(1, TILE_SIZE), + allow_tma=False, + latency=1, + ) + xj = ct.astype(xj, ct.float32) + yj = xj * rms * wj + yj = ct.astype(yj, x.dtype) + ct.store( + out, + index=(row, j), + tile=yj, + allow_tma=False, + latency=1, + ) + + @ct.kernel + def rms_norm_kernel_gather( + x, + w, + out, + Rstd, + N: ct.Constant[int], + eps: ct.Constant[float], + TILE_SIZE: ct.Constant[int], + use_gemma: ct.Constant[bool], + ): + """Standard RMSNorm kernel for non-static persistent mode with ptr loads""" + row = ct.bid(0) + _rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32) + num_tiles = ct.cdiv(N, TILE_SIZE) + offsets = ct.arange(TILE_SIZE, dtype=ct.int32) + + for j in range(0, num_tiles): + offs = j * TILE_SIZE + offsets + xj = ct.gather(x, (row, offs), latency=1) + xj = ct.astype(xj, ct.float32) + _rms += xj * xj + + # Calculate RMS Norm + rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps) + ct.scatter(Rstd, row, rms) + + for j in range(0, num_tiles): + offs = j * TILE_SIZE + offsets + wj = ct.gather(w, offs, latency=1) + wj = ct.astype(wj, ct.float32) + # Apply Gemma-style bias if enabled + if use_gemma: + wj = wj + 1.0 + xj = ct.gather(x, (row, offs), latency=1) + xj = ct.astype(xj, ct.float32) + yj = xj * rms * wj + yj = ct.astype(yj, x.dtype) + ct.scatter(out, (row, offs), yj, latency=1) + + @ct.kernel + def rms_norm_kernel_static_persistent( + X, # Input tensor + Y, # Output tensor + W, # Weight tensor + TILE_SIZE_M: ct.Constant[int], # 4 rows per block + TILE_SIZE_N: ct.Constant[int], # columns per block + eps: ct.Constant[float], # Epsilon value + use_gemma: ct.Constant[bool], # Gemma-style weight bias + ): + """ + CuTile static persistent RMSNorm kernel that processes multiple blocks per program. + Each program processes multiple blocks in a loop for better efficiency. 
+ """ + # Get program ID + pid = ct.bid(0) + + # Infer tensor dimensions from input shape + M = X.shape[0] # Number of rows + N = X.shape[1] # Number of columns + + # Calculate upper bound - number of row blocks to process + upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M + + # Load weight vector once (shared across all blocks processed by this program) + w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,)) + w = ct.astype(w, ct.float32) + # Apply Gemma-style bias if enabled + if use_gemma: + w = w + 1.0 + + # Static persistent loop: each program processes multiple blocks + num_tiles_x = ct.num_blocks(0) + for current_block in range(pid, upper_bound, num_tiles_x): + # Load input tile + x = ct.load( + X, + index=(current_block, 0), + shape=(TILE_SIZE_M, TILE_SIZE_N), + latency=10, # +2% perf from this hint + ) + x = ct.astype(x, ct.float32) + + # Step 1: Compute x^2 + x_squared = ct.mul(x, x) + + # Step 2: Reduce sum along axis=1 (columns) + x2_sum = ct.sum(x_squared, axis=1, keepdims=True) # Shape: [TILE_SIZE_M, 1] + + # Step 3: Compute variance (divide by N) + N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32) + variance = ct.truediv(x2_sum, N_f32) + + # Step 4: Add epsilon and compute rsqrt + eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32) + variance_eps = ct.add(variance, eps_tensor) + rsqrt_var = ct.rsqrt(variance_eps) + + # Step 5: Apply normalization + x_normalized = ct.mul(x, rsqrt_var) + + # Step 6: Apply linear transformation + # Broadcast weight to match input shape + w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N)) + b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32) + + # Apply linear transformation: y = x_normalized * w + b + y = ct.mul(x_normalized, w_broadcasted) + y = ct.add(y, b_broadcasted) + + # Convert back to original dtype + y = ct.astype(y, X.dtype) + + # Store result + ct.store( + Y, + index=(current_block, 0), + tile=y, + allow_tma=False, # +30% perf + latency=3, # +3% perf from this hint + ) diff --git a/tensorrt_llm/_torch/cuda_tile_kernels/rms_norm_fuse_residual.py b/tensorrt_llm/_torch/cuda_tile_kernels/rms_norm_fuse_residual.py new file mode 100644 index 0000000000..468348976c --- /dev/null +++ b/tensorrt_llm/_torch/cuda_tile_kernels/rms_norm_fuse_residual.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from rms_norm.py with residual fusion support +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE + +if IS_CUDA_TILE_AVAILABLE: + import cuda.tile as ct + + @ct.kernel + def rms_norm_fuse_residual_kernel( + x, + residual, + w, + Rstd, + N: ct.Constant[int], + eps: ct.Constant[float], + TILE_SIZE: ct.Constant[int], + use_gemma: ct.Constant[bool], + ): + """RMSNorm kernel with residual fusion for non-static persistent mode with tiled loads""" + row = ct.bid(0) + _rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32) + num_tiles = ct.cdiv(x.shape[1], TILE_SIZE) + + # First pass: compute RMS with fused residual addition and store sum to residual + for j in range(0, num_tiles): + xj = ct.load( + x, + index=(row, j), + shape=(1, TILE_SIZE), + allow_tma=False, + latency=1, + ) + residual_j = ct.load( + residual, + index=(row, j), + shape=(1, TILE_SIZE), + allow_tma=False, + latency=1, + ) + # Fuse residual: convert to float32, add, then use for RMS computation + xj = ct.astype(xj, ct.float32) + residual_j = ct.astype(residual_j, ct.float32) + xj = xj + residual_j + _rms += xj * xj + + # Store the sum (new residual) back to residual tensor + xj_stored = ct.astype(xj, residual.dtype) + ct.store( + residual, + index=(row, j), + tile=xj_stored, + allow_tma=False, + latency=1, + ) + + # Calculate RMS Norm + rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps) + ct.store(Rstd, index=(row,), tile=rms) + + # Second pass: load from residual (which now contains the sum), apply normalization, store to x + for j in range(0, num_tiles): + wj = ct.load( + w, + index=(j,), + shape=(TILE_SIZE,), + allow_tma=False, + latency=1, + ) + wj = ct.astype(wj, ct.float32) + # Apply Gemma-style bias if enabled + if use_gemma: + wj = wj + 1.0 + residual_j = ct.load( + residual, + index=(row, j), + shape=(1, TILE_SIZE), + allow_tma=False, + latency=1, + ) + # Load from residual (which now contains x + residual sum) + residual_j = ct.astype(residual_j, ct.float32) + yj = residual_j * rms * wj + yj = ct.astype(yj, x.dtype) + ct.store( + x, + index=(row, j), + tile=yj, + allow_tma=False, + latency=1, + ) + + @ct.kernel + def rms_norm_fuse_residual_kernel_gather( + x, + residual, + w, + Rstd, + N: ct.Constant[int], + eps: ct.Constant[float], + TILE_SIZE: ct.Constant[int], + use_gemma: ct.Constant[bool], + ): + """RMSNorm kernel with residual fusion for non-static persistent mode with ptr loads""" + row = ct.bid(0) + _rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32) + num_tiles = ct.cdiv(N, TILE_SIZE) + offsets = ct.arange(TILE_SIZE, dtype=ct.int32) + + # First pass: compute RMS with fused residual addition and store sum to residual + for j in range(0, num_tiles): + offs = j * TILE_SIZE + offsets + xj = ct.gather(x, (row, offs), latency=1) + residual_j = ct.gather(residual, (row, offs), latency=1) + # Fuse residual: convert to float32, add, then use for RMS computation + xj = ct.astype(xj, ct.float32) + residual_j = ct.astype(residual_j, ct.float32) + xj = xj + residual_j + _rms += xj * xj + + # Store the sum (new residual) back to residual tensor + xj_stored = ct.astype(xj, residual.dtype) + ct.scatter(residual, (row, offs), xj_stored, latency=1) + + # Calculate RMS Norm + rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps) + ct.scatter(Rstd, row, rms) + + # Second pass: load from residual (which now contains the sum), apply normalization, store to x + for j in range(0, num_tiles): + offs = j * TILE_SIZE + offsets + wj = ct.gather(w, offs, latency=1) + wj = ct.astype(wj, ct.float32) + # 
Apply Gemma-style bias if enabled + if use_gemma: + wj = wj + 1.0 + residual_j = ct.gather(residual, (row, offs), latency=1) + # Load from residual (which now contains x + residual sum) + residual_j = ct.astype(residual_j, ct.float32) + yj = residual_j * rms * wj + yj = ct.astype(yj, x.dtype) + ct.scatter(x, (row, offs), yj, latency=1) + + @ct.kernel + def rms_norm_fuse_residual_kernel_static_persistent( + X, # Input tensor + Residual, # Residual tensor + W, # Weight tensor + TILE_SIZE_M: ct.Constant[int], # 4 rows per block + TILE_SIZE_N: ct.Constant[int], # columns per block + eps: ct.Constant[float], # Epsilon value + use_gemma: ct.Constant[bool], # Gemma-style weight bias + ): + """ + CuTile static persistent RMSNorm kernel with residual fusion that processes multiple blocks per program. + Each program processes multiple blocks in a loop for better efficiency. + """ + # Get program ID + pid = ct.bid(0) + + # Infer tensor dimensions from input shape + M = X.shape[0] # Number of rows + N = X.shape[1] # Number of columns + + # Calculate upper bound - number of row blocks to process + upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M + + # Load weight vector once (shared across all blocks processed by this program) + w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,)) + w = ct.astype(w, ct.float32) + # Apply Gemma-style bias if enabled + if use_gemma: + w = w + 1.0 + + # Static persistent loop: each program processes multiple blocks + num_tiles_x = ct.num_blocks(0) + for current_block in range(pid, upper_bound, num_tiles_x): + # Load input tile + x = ct.load( + X, + index=(current_block, 0), + shape=(TILE_SIZE_M, TILE_SIZE_N), + latency=10, # +2% perf from this hint + ) + # Load residual tile + residual = ct.load( + Residual, + index=(current_block, 0), + shape=(TILE_SIZE_M, TILE_SIZE_N), + latency=10, + ) + + # Fuse residual: convert to float32 and add + x = ct.astype(x, ct.float32) + residual = ct.astype(residual, ct.float32) + x = ct.add(x, residual) + + # Store the sum (new residual) back to Residual tensor + x_stored = ct.astype(x, Residual.dtype) + ct.store( + Residual, + index=(current_block, 0), + tile=x_stored, + allow_tma=False, + latency=3, + ) + + # Step 1: Compute x^2 + x_squared = ct.mul(x, x) + + # Step 2: Reduce sum along axis=1 (columns) + x2_sum = ct.sum(x_squared, axis=1, keepdims=True) # Shape: [TILE_SIZE_M, 1] + + # Step 3: Compute variance (divide by N) + N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32) + variance = ct.truediv(x2_sum, N_f32) + + # Step 4: Add epsilon and compute rsqrt + eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32) + variance_eps = ct.add(variance, eps_tensor) + rsqrt_var = ct.rsqrt(variance_eps) + + # Step 5: Apply normalization + x_normalized = ct.mul(x, rsqrt_var) + + # Step 6: Apply linear transformation + # Broadcast weight to match input shape + w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N)) + b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32) + + # Apply linear transformation: y = x_normalized * w + b + y = ct.mul(x_normalized, w_broadcasted) + y = ct.add(y, b_broadcasted) + + # Convert back to original dtype and store to X (new hidden_states) + y = ct.astype(y, X.dtype) + + # Store result to X + ct.store( + X, + index=(current_block, 0), + tile=y, + allow_tma=False, # +30% perf + latency=3, # +3% perf from this hint + ) diff --git a/tensorrt_llm/_torch/cuda_tile_utils.py b/tensorrt_llm/_torch/cuda_tile_utils.py new file mode 100644 index 0000000000..b16e688189 --- /dev/null +++ 
b/tensorrt_llm/_torch/cuda_tile_utils.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import importlib.metadata +import os +import platform +import shutil + +import torch + +from ..logger import logger + +IS_CUDA_TILE_AVAILABLE = False + + +@functools.lru_cache() +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def ceil_div(a, b): + return (a + b - 1) // b + + +if platform.system() != "Windows": + try: + import cuda.tile # noqa: F401 + except ImportError: + logger.warning("cuda-tile package not found, TileIR kernels will not be available") + else: + if (cc := torch.cuda.get_device_properties()) and (cc.major, cc.minor) < (10, 0): + logger.warning( + f"TileIR requires compute capability 10.0 or higher, but the current device has " + f"{cc.major}.{cc.minor}. TileIR kernels will not be available" + ) + elif shutil.which("tileiras") is not None: + IS_CUDA_TILE_AVAILABLE = True + # For systems without tileiras installed, try to locate from nvidia-cuda-tileiras package. + elif tileiras_files := importlib.metadata.files("nvidia-cuda-tileiras"): + for pkg_file in tileiras_files: + if pkg_file.name == "tileiras": + tileiras_dir = pkg_file.locate().parent + os.environ["PATH"] = f"{os.environ['PATH']}:{tileiras_dir}" + break + assert shutil.which("tileiras") is not None + IS_CUDA_TILE_AVAILABLE = True + else: + logger.warning("tileiras compiler not found, TileIR kernels will not be available") diff --git a/tensorrt_llm/_torch/custom_ops/__init__.py b/tensorrt_llm/_torch/custom_ops/__init__.py index 5fb2927b01..ef58979ef6 100644 --- a/tensorrt_llm/_torch/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/custom_ops/__init__.py @@ -1,3 +1,4 @@ +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..modules.attention import attn_custom_op_inplace, mla_custom_op_inplace @@ -38,3 +39,11 @@ if IS_CUTLASS_DSL_AVAILABLE: __all__ += [ 'cute_dsl_nvfp4_gemm_blackwell', ] + +if IS_CUDA_TILE_AVAILABLE: + from .cuda_tile_custom_ops import (cuda_tile_rms_norm, + cuda_tile_rms_norm_fuse_residual_) + __all__ += [ + 'cuda_tile_rms_norm', + 'cuda_tile_rms_norm_fuse_residual_', + ] diff --git a/tensorrt_llm/_torch/custom_ops/cuda_tile_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cuda_tile_custom_ops.py new file mode 100644 index 0000000000..0dc6b82680 --- /dev/null +++ b/tensorrt_llm/_torch/custom_ops/cuda_tile_custom_ops.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/bench_rms_norm.py + +import torch + +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE + +if IS_CUDA_TILE_AVAILABLE: + import cuda.tile as ct + + from ..cuda_tile_kernels import ( + rms_norm_fuse_residual_kernel, + rms_norm_fuse_residual_kernel_gather, + rms_norm_fuse_residual_kernel_static_persistent, + rms_norm_kernel, + rms_norm_kernel_gather, + rms_norm_kernel_static_persistent, + ) + from ..cuda_tile_utils import ceil_div, next_power_of_2 + + @torch.library.custom_op("trtllm::cuda_tile_rms_norm", mutates_args=()) + def cuda_tile_rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + static_persistent: bool, + gather: bool, + use_gemma: bool, + ) -> torch.Tensor: + x = x.contiguous() + weight = weight.contiguous() + + # Allocate output tensor + y = torch.empty_like(x) + M, N = x.shape + + if static_persistent: + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + TILE_SIZE_M = 4 # Default value, could be made configurable + TILE_SIZE_N = next_power_of_2(N) + + # Other tile sizes are more optimal when other dimension is too large/too small + if TILE_SIZE_N <= 1024: + TILE_SIZE_M = 16 + elif TILE_SIZE_N >= 16384: + TILE_SIZE_M = 2 + + grid_size = min( + NUM_SMS, + ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N), + ) + grid = (grid_size,) + ct.launch( + torch.cuda.current_stream(), + grid, + rms_norm_kernel_static_persistent, + ( + x, + y, + weight, + TILE_SIZE_M, + TILE_SIZE_N, + eps, + use_gemma, + ), + ) + else: + # Standard RMSNorm kernel + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + MAX_FUSED_SIZE = 2048 // x.element_size() + TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N)) + grid = (M,) + kernel = rms_norm_kernel_gather if gather else rms_norm_kernel + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + x, + weight, + y, + rstd, + N, + eps, + TILE_SIZE, + use_gemma, + ), + ) + return y.view(*x.shape) + + @cuda_tile_rms_norm.register_fake + def _( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + static_persistent: bool, + gather: bool, + use_gemma: bool, + ) -> torch.Tensor: + return torch.empty_like(x.contiguous()) + + @torch.library.custom_op( + "trtllm::cuda_tile_rms_norm_fuse_residual_", + mutates_args=("x", "residual"), + ) + def cuda_tile_rms_norm_fuse_residual_( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, + static_persistent: bool, + gather: bool, + use_gemma: bool, + ) -> None: + assert x.is_contiguous(), "x must be contiguous for in-place operation" + assert residual.is_contiguous(), "residual must be contiguous for in-place operation" + weight = weight.contiguous() + + M, N = x.shape + + if static_persistent: + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + TILE_SIZE_M = 4 # Default value, could be made configurable + TILE_SIZE_N = next_power_of_2(N) + + # Other tile 
sizes are more optimal when other dimension is too large/too small + if TILE_SIZE_N <= 1024: + TILE_SIZE_M = 16 + elif TILE_SIZE_N >= 16384: + TILE_SIZE_M = 2 + + grid_size = min( + NUM_SMS, + ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N), + ) + grid = (grid_size,) + ct.launch( + torch.cuda.current_stream(), + grid, + rms_norm_fuse_residual_kernel_static_persistent, + ( + x, + residual, + weight, + TILE_SIZE_M, + TILE_SIZE_N, + eps, + use_gemma, + ), + ) + else: + # Standard RMSNorm kernel with residual fusion + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + MAX_FUSED_SIZE = 2048 // x.element_size() + TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N)) + grid = (M,) + kernel = ( + rms_norm_fuse_residual_kernel_gather if gather else rms_norm_fuse_residual_kernel + ) + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + x, + residual, + weight, + rstd, + N, + eps, + TILE_SIZE, + use_gemma, + ), + ) diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index a3e6bcde96..d6e0a5994b 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -20,6 +20,7 @@ from typing import Optional, Tuple, TypeAlias, Union, cast import torch from torch import nn +from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..utils import Fp4QuantizedTensor @@ -39,6 +40,7 @@ class RMSNorm(nn.Module): has_weights: bool = True, use_gemma: bool = False, quantize_type: Optional[str] = None, + use_cuda_tile: bool = False, ): super().__init__() @@ -49,6 +51,10 @@ class RMSNorm(nn.Module): raise NotImplementedError( f"Quantize type {quantize_type} not implemented in RMSNorm") self.is_nvfp4 = quantize_type == "nvfp4" + if use_cuda_tile and not IS_CUDA_TILE_AVAILABLE: + raise ValueError( + "cuda.tile is not available, please install cuda-tile pypi package" + ) if has_weights: if not use_gemma: @@ -65,6 +71,7 @@ class RMSNorm(nn.Module): persistent=False) self.variance_epsilon = eps self.use_gemma = use_gemma + self.use_cuda_tile = use_cuda_tile def forward( self, @@ -127,8 +134,30 @@ class RMSNorm(nn.Module): hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused) return (hidden_states_fused, residual_out) if has_residual else hidden_states_fused - - if IS_FLASHINFER_AVAILABLE: + elif self.use_cuda_tile: + if isinstance(residual, torch.Tensor): + # Use fused residual kernel + hidden_states = hidden_states.contiguous() + residual = residual.contiguous() + torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_( + x=hidden_states, + residual=residual, + weight=self.weight, + eps=self.variance_epsilon, + static_persistent=True, + gather=True, + use_gemma=self.use_gemma, + ) + else: + hidden_states = torch.ops.trtllm.cuda_tile_rms_norm( + x=hidden_states, + weight=self.weight, + eps=self.variance_epsilon, + static_persistent=True, + gather=True, + use_gemma=self.use_gemma, + ) + elif IS_FLASHINFER_AVAILABLE: from ..custom_ops import (flashinfer_fused_add_rmsnorm, flashinfer_gemma_fused_add_rmsnorm, flashinfer_gemma_rmsnorm, diff --git a/tests/unittest/_torch/thop/parallel/test_cuda_tile_custom_ops.py b/tests/unittest/_torch/thop/parallel/test_cuda_tile_custom_ops.py new file mode 100644 index 0000000000..cd0d4e6cc8 --- /dev/null +++ b/tests/unittest/_torch/thop/parallel/test_cuda_tile_custom_ops.py @@ -0,0 +1,309 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch +import torch.nn.functional as F + +import tensorrt_llm # noqa: F401 +from tensorrt_llm._torch.cuda_tile_utils import IS_CUDA_TILE_AVAILABLE + +# Skip all tests if CUDA tile is not available +pytestmark = pytest.mark.skipif(not IS_CUDA_TILE_AVAILABLE, reason="CUDA tile is not available") + + +@pytest.fixture(autouse=True) +def prepare_testcase_environment(tmp_path): + """Set random seed and enable deterministic mode before each test.""" + # Enable deterministic execution. + prev_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG") + prev_deterministic_mode = torch.are_deterministic_algorithms_enabled() + torch.manual_seed(19260817) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.use_deterministic_algorithms(True) + # Enable cuTile debug dump. + os.environ["CUDA_TILE_DUMP_BYTECODE"] = str(tmp_path / "bytecode") + os.environ["CUDA_TILE_DUMP_TILEIR"] = str(tmp_path / "tileir") + + yield + + # Rewind to previous states. + if prev_cublas_workspace_config is not None: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = prev_cublas_workspace_config + else: + del os.environ["CUBLAS_WORKSPACE_CONFIG"] + torch.use_deterministic_algorithms(prev_deterministic_mode) + del os.environ["CUDA_TILE_DUMP_BYTECODE"] + del os.environ["CUDA_TILE_DUMP_TILEIR"] + + +def reference_rms_norm( + hidden_states: torch.Tensor, + weight: torch.Tensor, + eps: float, + use_gemma: bool, + residual: torch.Tensor | None = None, +): + """ + Reference RMSNorm implementation using PyTorch operations. 
+ + Args: + hidden_states: Input tensor + weight: Weight tensor + eps: Epsilon for numerical stability + use_gemma: Whether to use Gemma-style weight bias (weight + 1) + residual: Optional residual tensor to add before normalization + + Returns: + Tuple of (normalized output, new residual) if residual is provided, + otherwise just normalized output + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + new_residual = None + if residual is not None: + hidden_states = hidden_states + residual.to(torch.float32) + new_residual = hidden_states.to(input_dtype) + + # Prepare weight with Gemma-style bias if needed + if use_gemma: + weight_to_apply = weight + 1.0 + else: + weight_to_apply = weight + + # Use torch.nn.functional.rms_norm for the normalization + hidden_states = F.rms_norm( + hidden_states, (hidden_states.shape[-1],), weight=weight_to_apply.to(torch.float32), eps=eps + ) + hidden_states = hidden_states.to(input_dtype) + + if residual is not None: + return hidden_states, new_residual + else: + return hidden_states + + +@pytest.mark.parametrize( + "M,N", + [ + (1, 128), + (4, 256), + (16, 512), + (32, 1024), + (64, 2048), + (128, 4096), + (8, 8192), + ], +) +@pytest.mark.parametrize("use_gemma", [False, True]) +@pytest.mark.parametrize("static_persistent", [False, True]) +@pytest.mark.parametrize("gather", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_cuda_tile_rms_norm(M, N, use_gemma, static_persistent, gather, dtype): + """Test cuda_tile_rms_norm operator against reference implementation.""" + eps = 1e-5 + + # Create input tensors + x = torch.randn(M, N, dtype=dtype, device="cuda") + weight = torch.randn(N, dtype=dtype, device="cuda") + + # Clone for reference computation + x_ref = x.clone() + weight_ref = weight.clone() + + # Compute reference + ref_output = reference_rms_norm(x_ref, weight_ref, eps, use_gemma) + + # Compute with cuda_tile kernel + cuda_output = torch.ops.trtllm.cuda_tile_rms_norm( + x=x, + weight=weight, + eps=eps, + static_persistent=static_persistent, + gather=gather, + use_gemma=use_gemma, + ) + + # Compare results + # Use relatively loose tolerance due to different computation orders + rtol = 1e-2 if dtype == torch.float16 else 5e-2 + atol = 1e-3 if dtype == torch.float16 else 5e-3 + + torch.testing.assert_close( + cuda_output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"cuda_tile_rms_norm output mismatch for M={M}, N={N}, " + f"use_gemma={use_gemma}, static_persistent={static_persistent}, " + f"gather={gather}, dtype={dtype}", + ) + + +@pytest.mark.parametrize( + "M,N", + [ + (1, 128), + (4, 256), + (16, 512), + (32, 1024), + (64, 2048), + (128, 4096), + (8, 8192), + ], +) +@pytest.mark.parametrize("use_gemma", [False, True]) +@pytest.mark.parametrize("static_persistent", [False, True]) +@pytest.mark.parametrize("gather", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_cuda_tile_rms_norm_fuse_residual(M, N, use_gemma, static_persistent, gather, dtype): + """Test cuda_tile_rms_norm_fuse_residual_ operator against reference implementation.""" + eps = 1e-5 + + # Create input tensors + x = torch.randn(M, N, dtype=dtype, device="cuda") + residual = torch.randn(M, N, dtype=dtype, device="cuda") + weight = torch.randn(N, dtype=dtype, device="cuda") + + # Clone for reference computation + x_ref = x.clone() + residual_ref = residual.clone() + weight_ref = weight.clone() + + # Compute reference + ref_output, ref_new_residual = 
reference_rms_norm( + x_ref, weight_ref, eps, use_gemma, residual_ref + ) + + # Ensure tensors are contiguous for in-place operation + x = x.contiguous() + residual = residual.contiguous() + + # Compute with cuda_tile kernel (in-place operation) + torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_( + x=x, + residual=residual, + weight=weight, + eps=eps, + static_persistent=static_persistent, + gather=gather, + use_gemma=use_gemma, + ) + + # After in-place operation: + # x contains the normalized output + # residual contains the un-normalized sum (new residual) + + # Compare results + # Use relatively loose tolerance due to different computation orders + rtol = 1e-2 if dtype == torch.float16 else 5e-2 + atol = 1e-3 if dtype == torch.float16 else 5e-3 + + torch.testing.assert_close( + x, + ref_output, + rtol=rtol, + atol=atol, + msg=f"cuda_tile_rms_norm_fuse_residual_ output mismatch for M={M}, N={N}, " + f"use_gemma={use_gemma}, static_persistent={static_persistent}, " + f"gather={gather}, dtype={dtype}", + ) + + torch.testing.assert_close( + residual, + ref_new_residual, + rtol=rtol, + atol=atol, + msg=f"cuda_tile_rms_norm_fuse_residual_ residual mismatch for M={M}, N={N}, " + f"use_gemma={use_gemma}, static_persistent={static_persistent}, " + f"gather={gather}, dtype={dtype}", + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_cuda_tile_rms_norm_fuse_residual_inplace(dtype): + """Test that fuse_residual operator truly modifies tensors in-place.""" + eps = 1e-5 + M, N = 16, 256 + + x = torch.randn(M, N, dtype=dtype, device="cuda").contiguous() + residual = torch.randn(M, N, dtype=dtype, device="cuda").contiguous() + weight = torch.randn(N, dtype=dtype, device="cuda") + + # Store original data pointers + x_data_ptr = x.data_ptr() + residual_data_ptr = residual.data_ptr() + + # Call in-place operator + torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_( + x=x, + residual=residual, + weight=weight, + eps=eps, + static_persistent=True, + gather=True, + use_gemma=False, + ) + + # Verify that tensors were modified in-place (same memory location) + assert x.data_ptr() == x_data_ptr, "x tensor was not modified in-place" + assert residual.data_ptr() == residual_data_ptr, "residual tensor was not modified in-place" + + +def test_cuda_tile_rms_norm_fuse_residual_requires_contiguous(): + """Test that fuse_residual operator requires contiguous tensors.""" + eps = 1e-5 + M, N = 16, 256 + dtype = torch.float16 + + # Create non-contiguous tensors + x = torch.randn(M, N * 2, dtype=dtype, device="cuda")[:, ::2] + residual = torch.randn(M, N, dtype=dtype, device="cuda").contiguous() + weight = torch.randn(N, dtype=dtype, device="cuda") + + assert not x.is_contiguous(), "x should be non-contiguous for this test" + + # Should raise assertion error for non-contiguous x + with pytest.raises(AssertionError, match="x must be contiguous"): + torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_( + x=x, + residual=residual, + weight=weight, + eps=eps, + static_persistent=True, + gather=True, + use_gemma=False, + ) + + # Create non-contiguous residual + x = torch.randn(M, N, dtype=dtype, device="cuda").contiguous() + residual = torch.randn(M, N * 2, dtype=dtype, device="cuda")[:, ::2] + + assert not residual.is_contiguous(), "residual should be non-contiguous for this test" + + # Should raise assertion error for non-contiguous residual + with pytest.raises(AssertionError, match="residual must be contiguous"): + torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_( + x=x, + 
residual=residual, + weight=weight, + eps=eps, + static_persistent=True, + gather=True, + use_gemma=False, + )
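
A minimal usage sketch of the two new custom ops follows, for orientation only. It assumes this patch is applied, that the cuda-tile and nvidia-cuda-tileiras packages listed in requirements.txt are installed, and that the GPU has compute capability 10.0 or higher so that IS_CUDA_TILE_AVAILABLE is True; the shapes, dtypes, and tolerances are illustrative values taken from the unit tests above, not requirements of the API.

# Sketch: exercising the new trtllm cuda.tile RMSNorm custom ops (assumes this
# patch is installed and cuda.tile kernels are available on the current GPU).
import torch
import torch.nn.functional as F

import tensorrt_llm  # noqa: F401  # importing registers the trtllm custom ops
from tensorrt_llm._torch.cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

assert IS_CUDA_TILE_AVAILABLE, "cuda.tile kernels are not available on this system"

M, N, eps = 32, 1024, 1e-5
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(N, dtype=torch.bfloat16, device="cuda")

# Out-of-place RMSNorm.
y = torch.ops.trtllm.cuda_tile_rms_norm(
    x=x, weight=weight, eps=eps,
    static_persistent=True, gather=True, use_gemma=False,
)

# Float32 reference, mirroring reference_rms_norm() in the unit tests.
ref = F.rms_norm(x.float(), (N,), weight=weight.float(), eps=eps).to(x.dtype)
torch.testing.assert_close(y, ref, rtol=5e-2, atol=5e-3)

# In-place fused-residual variant: x is overwritten with the normalized output,
# residual is overwritten with the pre-normalization sum (x + residual).
x2, res2 = x.clone().contiguous(), residual.clone().contiguous()
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
    x=x2, residual=res2, weight=weight, eps=eps,
    static_persistent=True, gather=True, use_gemma=False,
)

The in-place variant deliberately writes the normalized output back into x and the running sum into residual, which is why the inplace_info() map in compilation/utils.py marks both arguments of cuda_tile_rms_norm_fuse_residual_ as mutated.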