[None][feat] Integrate cuda.tile RMS norm kernels (#9725)

Signed-off-by: Rundong (David) Li <davidli@nvidia.com>
Co-authored-by: Jinman Xie <jinmanx@nvidia.com>
Co-authored-by: Alexey Bylinkin <abylinkin@nvidia.com>
Co-authored-by: Qiqi Xiao <qiqix@nvidia.com>
Co-authored-by: Biao Wang <biaow@nvidia.com>
Co-authored-by: Thomas Schmid <thschmid@nvidia.com>
Rundong Li 2026-02-02 19:44:27 +08:00 committed by GitHub
parent 13b0ab9c0e
commit f1b85fea4c
10 changed files with 1113 additions and 3 deletions

View File

@@ -82,3 +82,5 @@ mistral-common==1.8.6
torchao>=0.14.1
cuda-core
llist
cuda-tile>=1.0.1
nvidia-cuda-tileiras>=13.1

View File

@@ -5,6 +5,8 @@ import torch
from torch.fx import Node
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
def get_symint_val(i: Union[torch.SymInt | int]):
if isinstance(i, int):
@@ -93,6 +95,13 @@ def inplace_info():
},
torch.ops.trtllm.pp_send_tensors.default: {
1: "tensors"
}
},
}
if IS_CUDA_TILE_AVAILABLE:
# cuda.tile availability depends on the GPU's compute capability, so it is checked at runtime.
inplace_map[
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_.default] = {
1: "x",
2: "residual"
}
return inplace_map
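
The new entry registers the fused op's mutated tensor arguments with the existing in-place map. A minimal sketch of what the added mapping looks up to (inplace_info is the function defined above; its import path is not shown in this diff, and the lookup only succeeds when cuda.tile is available):

# The fused op's mutated tensor arguments, keyed by argument index:
op = torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_.default
assert inplace_info()[op] == {1: "x", 2: "residual"}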

View File

@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
if IS_CUDA_TILE_AVAILABLE:
from .rms_norm import rms_norm_kernel, rms_norm_kernel_gather, rms_norm_kernel_static_persistent
from .rms_norm_fuse_residual import (
rms_norm_fuse_residual_kernel,
rms_norm_fuse_residual_kernel_gather,
rms_norm_fuse_residual_kernel_static_persistent,
)
__all__ = [
"rms_norm_kernel",
"rms_norm_kernel_gather",
"rms_norm_kernel_static_persistent",
"rms_norm_fuse_residual_kernel",
"rms_norm_fuse_residual_kernel_gather",
"rms_norm_fuse_residual_kernel_static_persistent",
]
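
Note that none of these symbols exist when cuda.tile is unavailable, so downstream code guards its imports the same way. A minimal sketch (the absolute package path is an assumption inferred from the relative imports above and the test import later in this commit):

from tensorrt_llm._torch.cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    # Package path assumed: cuda_tile_kernels sits next to cuda_tile_utils.
    from tensorrt_llm._torch.cuda_tile_kernels import rms_norm_kernel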

View File

@@ -0,0 +1,203 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/kernels/rms_norm.py
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
if IS_CUDA_TILE_AVAILABLE:
import cuda.tile as ct
@ct.kernel
def rms_norm_kernel(
x,
w,
out,
Rstd,
N: ct.Constant[int],
eps: ct.Constant[float],
TILE_SIZE: ct.Constant[int],
use_gemma: ct.Constant[bool],
):
"""Standard RMSNorm kernel for non-static persistent mode with tiled loads"""
row = ct.bid(0)
_rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32)
num_tiles = ct.cdiv(x.shape[1], TILE_SIZE)
for j in range(0, num_tiles):
xj = ct.load(
x,
index=(row, j),
shape=(1, TILE_SIZE),
allow_tma=False,
latency=1,
)
xj = ct.astype(xj, ct.float32)
_rms += xj * xj
# Calculate RMS Norm
rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps)
ct.store(Rstd, index=(row,), tile=rms)
for j in range(0, num_tiles):
wj = ct.load(
w,
index=(j,),
shape=(TILE_SIZE,),
allow_tma=False,
latency=1,
)
wj = ct.astype(wj, ct.float32)
# Apply Gemma-style bias if enabled
if use_gemma:
wj = wj + 1.0
xj = ct.load(
x,
index=(row, j),
shape=(1, TILE_SIZE),
allow_tma=False,
latency=1,
)
xj = ct.astype(xj, ct.float32)
yj = xj * rms * wj
yj = ct.astype(yj, x.dtype)
ct.store(
out,
index=(row, j),
tile=yj,
allow_tma=False,
latency=1,
)
@ct.kernel
def rms_norm_kernel_gather(
x,
w,
out,
Rstd,
N: ct.Constant[int],
eps: ct.Constant[float],
TILE_SIZE: ct.Constant[int],
use_gemma: ct.Constant[bool],
):
"""Standard RMSNorm kernel for non-static persistent mode with ptr loads"""
row = ct.bid(0)
_rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32)
num_tiles = ct.cdiv(N, TILE_SIZE)
offsets = ct.arange(TILE_SIZE, dtype=ct.int32)
for j in range(0, num_tiles):
offs = j * TILE_SIZE + offsets
xj = ct.gather(x, (row, offs), latency=1)
xj = ct.astype(xj, ct.float32)
_rms += xj * xj
# Calculate RMS Norm
rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps)
ct.scatter(Rstd, row, rms)
for j in range(0, num_tiles):
offs = j * TILE_SIZE + offsets
wj = ct.gather(w, offs, latency=1)
wj = ct.astype(wj, ct.float32)
# Apply Gemma-style bias if enabled
if use_gemma:
wj = wj + 1.0
xj = ct.gather(x, (row, offs), latency=1)
xj = ct.astype(xj, ct.float32)
yj = xj * rms * wj
yj = ct.astype(yj, x.dtype)
ct.scatter(out, (row, offs), yj, latency=1)
@ct.kernel
def rms_norm_kernel_static_persistent(
X, # Input tensor
Y, # Output tensor
W, # Weight tensor
TILE_SIZE_M: ct.Constant[int],  # rows per block (typically 4)
TILE_SIZE_N: ct.Constant[int], # columns per block
eps: ct.Constant[float], # Epsilon value
use_gemma: ct.Constant[bool], # Gemma-style weight bias
):
"""
CuTile static persistent RMSNorm kernel. Each program processes multiple row blocks
in a loop, so a small fixed grid (no larger than the SM count) covers all rows.
"""
# Get program ID
pid = ct.bid(0)
# Infer tensor dimensions from input shape
M = X.shape[0] # Number of rows
N = X.shape[1] # Number of columns
# Calculate upper bound - number of row blocks to process
upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M
# Load weight vector once (shared across all blocks processed by this program)
w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,))
w = ct.astype(w, ct.float32)
# Apply Gemma-style bias if enabled
if use_gemma:
w = w + 1.0
# Static persistent loop: each program processes multiple blocks
num_tiles_x = ct.num_blocks(0)
for current_block in range(pid, upper_bound, num_tiles_x):
# Load input tile
x = ct.load(
X,
index=(current_block, 0),
shape=(TILE_SIZE_M, TILE_SIZE_N),
latency=10, # +2% perf from this hint
)
x = ct.astype(x, ct.float32)
# Step 1: Compute x^2
x_squared = ct.mul(x, x)
# Step 2: Reduce sum along axis=1 (columns)
x2_sum = ct.sum(x_squared, axis=1, keepdims=True) # Shape: [TILE_SIZE_M, 1]
# Step 3: Compute variance (divide by N)
N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32)
variance = ct.truediv(x2_sum, N_f32)
# Step 4: Add epsilon and compute rsqrt
eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32)
variance_eps = ct.add(variance, eps_tensor)
rsqrt_var = ct.rsqrt(variance_eps)
# Step 5: Apply normalization
x_normalized = ct.mul(x, rsqrt_var)
# Step 6: Apply linear transformation
# Broadcast weight to match input shape
w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N))
b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32)
# Apply linear transformation: y = x_normalized * w + b
y = ct.mul(x_normalized, w_broadcasted)
y = ct.add(y, b_broadcasted)
# Convert back to original dtype
y = ct.astype(y, X.dtype)
# Store result
ct.store(
Y,
index=(current_block, 0),
tile=y,
allow_tma=False, # +30% perf
latency=3, # +3% perf from this hint
)
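
For reference, all three kernels above implement the same RMSNorm math with float32 accumulation; for a row x of length N, weight w, and epsilon eps:

r = \frac{1}{\sqrt{\tfrac{1}{N}\sum_{i=1}^{N} x_i^2 + \varepsilon}},
\qquad
y_i = x_i \, r \, w_i
\quad \text{(or } y_i = x_i \, r \, (w_i + 1) \text{ when use\_gemma is set)}

They differ only in how data is moved: tiled loads/stores, pointer gather/scatter, or a persistent grid that processes several rows per program.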

View File

@@ -0,0 +1,258 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from rms_norm.py with residual fusion support
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
if IS_CUDA_TILE_AVAILABLE:
import cuda.tile as ct
@ct.kernel
def rms_norm_fuse_residual_kernel(
x,
residual,
w,
Rstd,
N: ct.Constant[int],
eps: ct.Constant[float],
TILE_SIZE: ct.Constant[int],
use_gemma: ct.Constant[bool],
):
"""RMSNorm kernel with residual fusion for non-static persistent mode with tiled loads"""
row = ct.bid(0)
_rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32)
num_tiles = ct.cdiv(x.shape[1], TILE_SIZE)
# First pass: compute RMS with fused residual addition and store sum to residual
for j in range(0, num_tiles):
xj = ct.load(
x,
index=(row, j),
shape=(1, TILE_SIZE),
allow_tma=False,
latency=1,
)
residual_j = ct.load(
residual,
index=(row, j),
shape=(1, TILE_SIZE),
allow_tma=False,
latency=1,
)
# Fuse residual: convert to float32, add, then use for RMS computation
xj = ct.astype(xj, ct.float32)
residual_j = ct.astype(residual_j, ct.float32)
xj = xj + residual_j
_rms += xj * xj
# Store the sum (new residual) back to residual tensor
xj_stored = ct.astype(xj, residual.dtype)
ct.store(
residual,
index=(row, j),
tile=xj_stored,
allow_tma=False,
latency=1,
)
# Calculate RMS Norm
rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps)
ct.store(Rstd, index=(row,), tile=rms)
# Second pass: load from residual (which now contains the sum), apply normalization, store to x
for j in range(0, num_tiles):
wj = ct.load(
w,
index=(j,),
shape=(TILE_SIZE,),
allow_tma=False,
latency=1,
)
wj = ct.astype(wj, ct.float32)
# Apply Gemma-style bias if enabled
if use_gemma:
wj = wj + 1.0
residual_j = ct.load(
residual,
index=(row, j),
shape=(1, TILE_SIZE),
allow_tma=False,
latency=1,
)
# Load from residual (which now contains x + residual sum)
residual_j = ct.astype(residual_j, ct.float32)
yj = residual_j * rms * wj
yj = ct.astype(yj, x.dtype)
ct.store(
x,
index=(row, j),
tile=yj,
allow_tma=False,
latency=1,
)
@ct.kernel
def rms_norm_fuse_residual_kernel_gather(
x,
residual,
w,
Rstd,
N: ct.Constant[int],
eps: ct.Constant[float],
TILE_SIZE: ct.Constant[int],
use_gemma: ct.Constant[bool],
):
"""RMSNorm kernel with residual fusion for non-static persistent mode with ptr loads"""
row = ct.bid(0)
_rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32)
num_tiles = ct.cdiv(N, TILE_SIZE)
offsets = ct.arange(TILE_SIZE, dtype=ct.int32)
# First pass: compute RMS with fused residual addition and store sum to residual
for j in range(0, num_tiles):
offs = j * TILE_SIZE + offsets
xj = ct.gather(x, (row, offs), latency=1)
residual_j = ct.gather(residual, (row, offs), latency=1)
# Fuse residual: convert to float32, add, then use for RMS computation
xj = ct.astype(xj, ct.float32)
residual_j = ct.astype(residual_j, ct.float32)
xj = xj + residual_j
_rms += xj * xj
# Store the sum (new residual) back to residual tensor
xj_stored = ct.astype(xj, residual.dtype)
ct.scatter(residual, (row, offs), xj_stored, latency=1)
# Calculate RMS Norm
rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps)
ct.scatter(Rstd, row, rms)
# Second pass: load from residual (which now contains the sum), apply normalization, store to x
for j in range(0, num_tiles):
offs = j * TILE_SIZE + offsets
wj = ct.gather(w, offs, latency=1)
wj = ct.astype(wj, ct.float32)
# Apply Gemma-style bias if enabled
if use_gemma:
wj = wj + 1.0
residual_j = ct.gather(residual, (row, offs), latency=1)
# Load from residual (which now contains x + residual sum)
residual_j = ct.astype(residual_j, ct.float32)
yj = residual_j * rms * wj
yj = ct.astype(yj, x.dtype)
ct.scatter(x, (row, offs), yj, latency=1)
@ct.kernel
def rms_norm_fuse_residual_kernel_static_persistent(
X, # Input tensor
Residual, # Residual tensor
W, # Weight tensor
TILE_SIZE_M: ct.Constant[int],  # rows per block (typically 4)
TILE_SIZE_N: ct.Constant[int], # columns per block
eps: ct.Constant[float], # Epsilon value
use_gemma: ct.Constant[bool], # Gemma-style weight bias
):
"""
CuTile static persistent RMSNorm kernel with residual fusion. Each program processes multiple
row blocks in a loop, so a small fixed grid (no larger than the SM count) covers all rows.
"""
# Get program ID
pid = ct.bid(0)
# Infer tensor dimensions from input shape
M = X.shape[0] # Number of rows
N = X.shape[1] # Number of columns
# Calculate upper bound - number of row blocks to process
upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M
# Load weight vector once (shared across all blocks processed by this program)
w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,))
w = ct.astype(w, ct.float32)
# Apply Gemma-style bias if enabled
if use_gemma:
w = w + 1.0
# Static persistent loop: each program processes multiple blocks
num_tiles_x = ct.num_blocks(0)
for current_block in range(pid, upper_bound, num_tiles_x):
# Load input tile
x = ct.load(
X,
index=(current_block, 0),
shape=(TILE_SIZE_M, TILE_SIZE_N),
latency=10, # +2% perf from this hint
)
# Load residual tile
residual = ct.load(
Residual,
index=(current_block, 0),
shape=(TILE_SIZE_M, TILE_SIZE_N),
latency=10,
)
# Fuse residual: convert to float32 and add
x = ct.astype(x, ct.float32)
residual = ct.astype(residual, ct.float32)
x = ct.add(x, residual)
# Store the sum (new residual) back to Residual tensor
x_stored = ct.astype(x, Residual.dtype)
ct.store(
Residual,
index=(current_block, 0),
tile=x_stored,
allow_tma=False,
latency=3,
)
# Step 1: Compute x^2
x_squared = ct.mul(x, x)
# Step 2: Reduce sum along axis=1 (columns)
x2_sum = ct.sum(x_squared, axis=1, keepdims=True) # Shape: [TILE_SIZE_M, 1]
# Step 3: Compute variance (divide by N)
N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32)
variance = ct.truediv(x2_sum, N_f32)
# Step 4: Add epsilon and compute rsqrt
eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32)
variance_eps = ct.add(variance, eps_tensor)
rsqrt_var = ct.rsqrt(variance_eps)
# Step 5: Apply normalization
x_normalized = ct.mul(x, rsqrt_var)
# Step 6: Apply linear transformation
# Broadcast weight to match input shape
w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N))
b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32)
# Apply linear transformation: y = x_normalized * w + b
y = ct.mul(x_normalized, w_broadcasted)
y = ct.add(y, b_broadcasted)
# Convert back to original dtype and store to X (new hidden_states)
y = ct.astype(y, X.dtype)
# Store result to X
ct.store(
X,
index=(current_block, 0),
tile=y,
allow_tma=False, # +30% perf
latency=3, # +3% perf from this hint
)
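
The fused-residual variants add one step to the math above and write both results in place: the residual buffer receives the pre-normalization sum, and the x (or X) buffer receives the normalized output:

z = x + \text{residual} \quad (\text{written back to residual}),
\qquad
y_i = \frac{z_i \, w_i}{\sqrt{\tfrac{1}{N}\sum_{j} z_j^2 + \varepsilon}} \quad (\text{written back to } x)

with w_i + 1 in place of w_i when use_gemma is set.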

View File

@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib.metadata
import os
import platform
import shutil
import torch
from ..logger import logger
IS_CUDA_TILE_AVAILABLE = False
@functools.lru_cache()
def next_power_of_2(n: int):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
def ceil_div(a, b):
return (a + b - 1) // b
if platform.system() != "Windows":
try:
import cuda.tile # noqa: F401
except ImportError:
logger.warning("cuda-tile package not found, TileIR kernels will not be available")
else:
if (cc := torch.cuda.get_device_properties()) and (cc.major, cc.minor) < (10, 0):
logger.warning(
f"TileIR requires compute capability 10.0 or higher, but the current device has "
f"{cc.major}.{cc.minor}. TileIR kernels will not be available"
)
elif shutil.which("tileiras") is not None:
IS_CUDA_TILE_AVAILABLE = True
# For systems without tileiras on PATH, try to locate it in the nvidia-cuda-tileiras package.
elif tileiras_files := importlib.metadata.files("nvidia-cuda-tileiras"):
for pkg_file in tileiras_files:
if pkg_file.name == "tileiras":
tileiras_dir = pkg_file.locate().parent
os.environ["PATH"] = f"{os.environ['PATH']}:{tileiras_dir}"
break
assert shutil.which("tileiras") is not None
IS_CUDA_TILE_AVAILABLE = True
else:
logger.warning("tileiras compiler not found, TileIR kernels will not be available")

View File

@@ -1,3 +1,4 @@
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..modules.attention import attn_custom_op_inplace, mla_custom_op_inplace
@@ -38,3 +39,11 @@ if IS_CUTLASS_DSL_AVAILABLE:
__all__ += [
'cute_dsl_nvfp4_gemm_blackwell',
]
if IS_CUDA_TILE_AVAILABLE:
from .cuda_tile_custom_ops import (cuda_tile_rms_norm,
cuda_tile_rms_norm_fuse_residual_)
__all__ += [
'cuda_tile_rms_norm',
'cuda_tile_rms_norm_fuse_residual_',
]

View File

@@ -0,0 +1,188 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/bench_rms_norm.py
import torch
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
if IS_CUDA_TILE_AVAILABLE:
import cuda.tile as ct
from ..cuda_tile_kernels import (
rms_norm_fuse_residual_kernel,
rms_norm_fuse_residual_kernel_gather,
rms_norm_fuse_residual_kernel_static_persistent,
rms_norm_kernel,
rms_norm_kernel_gather,
rms_norm_kernel_static_persistent,
)
from ..cuda_tile_utils import ceil_div, next_power_of_2
@torch.library.custom_op("trtllm::cuda_tile_rms_norm", mutates_args=())
def cuda_tile_rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
static_persistent: bool,
gather: bool,
use_gemma: bool,
) -> torch.Tensor:
x = x.contiguous()
weight = weight.contiguous()
# Allocate output tensor
y = torch.empty_like(x)
M, N = x.shape
if static_persistent:
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
TILE_SIZE_M = 4 # Default value, could be made configurable
TILE_SIZE_N = next_power_of_2(N)
# Adjust rows per block: use more rows when rows are narrow, fewer when rows are very wide
if TILE_SIZE_N <= 1024:
TILE_SIZE_M = 16
elif TILE_SIZE_N >= 16384:
TILE_SIZE_M = 2
grid_size = min(
NUM_SMS,
ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N),
)
grid = (grid_size,)
ct.launch(
torch.cuda.current_stream(),
grid,
rms_norm_kernel_static_persistent,
(
x,
y,
weight,
TILE_SIZE_M,
TILE_SIZE_N,
eps,
use_gemma,
),
)
else:
# Standard RMSNorm kernel
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
MAX_FUSED_SIZE = 2048 // x.element_size()
TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N))
grid = (M,)
kernel = rms_norm_kernel_gather if gather else rms_norm_kernel
ct.launch(
torch.cuda.current_stream(),
grid,
kernel,
(
x,
weight,
y,
rstd,
N,
eps,
TILE_SIZE,
use_gemma,
),
)
return y.view(*x.shape)
@cuda_tile_rms_norm.register_fake
def _(
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
static_persistent: bool,
gather: bool,
use_gemma: bool,
) -> torch.Tensor:
return torch.empty_like(x.contiguous())
@torch.library.custom_op(
"trtllm::cuda_tile_rms_norm_fuse_residual_",
mutates_args=("x", "residual"),
)
def cuda_tile_rms_norm_fuse_residual_(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float,
static_persistent: bool,
gather: bool,
use_gemma: bool,
) -> None:
assert x.is_contiguous(), "x must be contiguous for in-place operation"
assert residual.is_contiguous(), "residual must be contiguous for in-place operation"
weight = weight.contiguous()
M, N = x.shape
if static_persistent:
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
TILE_SIZE_M = 4 # Default value, could be made configurable
TILE_SIZE_N = next_power_of_2(N)
# Adjust rows per block: use more rows when rows are narrow, fewer when rows are very wide
if TILE_SIZE_N <= 1024:
TILE_SIZE_M = 16
elif TILE_SIZE_N >= 16384:
TILE_SIZE_M = 2
grid_size = min(
NUM_SMS,
ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N),
)
grid = (grid_size,)
ct.launch(
torch.cuda.current_stream(),
grid,
rms_norm_fuse_residual_kernel_static_persistent,
(
x,
residual,
weight,
TILE_SIZE_M,
TILE_SIZE_N,
eps,
use_gemma,
),
)
else:
# Standard RMSNorm kernel with residual fusion
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
MAX_FUSED_SIZE = 2048 // x.element_size()
TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N))
grid = (M,)
kernel = (
rms_norm_fuse_residual_kernel_gather if gather else rms_norm_fuse_residual_kernel
)
ct.launch(
torch.cuda.current_stream(),
grid,
kernel,
(
x,
residual,
weight,
rstd,
N,
eps,
TILE_SIZE,
use_gemma,
),
)
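
To make the non-persistent tile-size selection concrete, a small worked example (the values follow directly from the code above; a bfloat16 element is 2 bytes, and 6000 is just an illustrative hidden size):

MAX_FUSED_SIZE = 2048 // 2                 # bfloat16: element_size() == 2, so 1024
TILE_SIZE = min(MAX_FUSED_SIZE, 8192)      # next_power_of_2(6000) == 8192, so TILE_SIZE == 1024
num_tiles_per_row = (6000 + TILE_SIZE - 1) // TILE_SIZE  # ceil(6000 / 1024) == 6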

View File

@@ -20,6 +20,7 @@ from typing import Optional, Tuple, TypeAlias, Union, cast
import torch
from torch import nn
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..utils import Fp4QuantizedTensor
@@ -39,6 +40,7 @@ class RMSNorm(nn.Module):
has_weights: bool = True,
use_gemma: bool = False,
quantize_type: Optional[str] = None,
use_cuda_tile: bool = False,
):
super().__init__()
@@ -49,6 +51,10 @@
raise NotImplementedError(
f"Quantize type {quantize_type} not implemented in RMSNorm")
self.is_nvfp4 = quantize_type == "nvfp4"
if use_cuda_tile and not IS_CUDA_TILE_AVAILABLE:
raise ValueError(
"cuda.tile is not available, please install cuda-tile pypi package"
)
if has_weights:
if not use_gemma:
@@ -65,6 +71,7 @@
persistent=False)
self.variance_epsilon = eps
self.use_gemma = use_gemma
self.use_cuda_tile = use_cuda_tile
def forward(
self,
@@ -127,8 +134,30 @@
hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused)
return (hidden_states_fused,
residual_out) if has_residual else hidden_states_fused
if IS_FLASHINFER_AVAILABLE:
elif self.use_cuda_tile:
if isinstance(residual, torch.Tensor):
# Use fused residual kernel
hidden_states = hidden_states.contiguous()
residual = residual.contiguous()
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
x=hidden_states,
residual=residual,
weight=self.weight,
eps=self.variance_epsilon,
static_persistent=True,
gather=True,
use_gemma=self.use_gemma,
)
else:
hidden_states = torch.ops.trtllm.cuda_tile_rms_norm(
x=hidden_states,
weight=self.weight,
eps=self.variance_epsilon,
static_persistent=True,
gather=True,
use_gemma=self.use_gemma,
)
elif IS_FLASHINFER_AVAILABLE:
from ..custom_ops import (flashinfer_fused_add_rmsnorm,
flashinfer_gemma_fused_add_rmsnorm,
flashinfer_gemma_rmsnorm,
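
A hedged usage sketch of the new flag. The leading constructor parameters and the forward signature are outside this diff, so hidden_size, eps, dtype, and the positional residual argument below are assumptions; only use_cuda_tile and use_gemma are taken from the hunk above:

# Hypothetical usage; argument names other than use_cuda_tile/use_gemma are assumptions.
norm = RMSNorm(hidden_size=4096, eps=1e-6, dtype=torch.bfloat16, use_cuda_tile=True)

out = norm(hidden_states)                     # no residual: trtllm::cuda_tile_rms_norm
out, new_res = norm(hidden_states, residual)  # with residual: in-place fuse_residual_ op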

View File

@@ -0,0 +1,309 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.nn.functional as F
import tensorrt_llm # noqa: F401
from tensorrt_llm._torch.cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
# Skip all tests if CUDA tile is not available
pytestmark = pytest.mark.skipif(not IS_CUDA_TILE_AVAILABLE, reason="CUDA tile is not available")
@pytest.fixture(autouse=True)
def prepare_testcase_environment(tmp_path):
"""Set random seed and enable deterministic mode before each test."""
# Enable deterministic execution.
prev_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
prev_deterministic_mode = torch.are_deterministic_algorithms_enabled()
torch.manual_seed(19260817)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)
# Enable cuTile debug dump.
os.environ["CUDA_TILE_DUMP_BYTECODE"] = str(tmp_path / "bytecode")
os.environ["CUDA_TILE_DUMP_TILEIR"] = str(tmp_path / "tileir")
yield
# Rewind to previous states.
if prev_cublas_workspace_config is not None:
os.environ["CUBLAS_WORKSPACE_CONFIG"] = prev_cublas_workspace_config
else:
del os.environ["CUBLAS_WORKSPACE_CONFIG"]
torch.use_deterministic_algorithms(prev_deterministic_mode)
del os.environ["CUDA_TILE_DUMP_BYTECODE"]
del os.environ["CUDA_TILE_DUMP_TILEIR"]
def reference_rms_norm(
hidden_states: torch.Tensor,
weight: torch.Tensor,
eps: float,
use_gemma: bool,
residual: torch.Tensor | None = None,
):
"""
Reference RMSNorm implementation using PyTorch operations.
Args:
hidden_states: Input tensor
weight: Weight tensor
eps: Epsilon for numerical stability
use_gemma: Whether to use Gemma-style weight bias (weight + 1)
residual: Optional residual tensor to add before normalization
Returns:
Tuple of (normalized output, new residual) if residual is provided,
otherwise just normalized output
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
new_residual = None
if residual is not None:
hidden_states = hidden_states + residual.to(torch.float32)
new_residual = hidden_states.to(input_dtype)
# Prepare weight with Gemma-style bias if needed
if use_gemma:
weight_to_apply = weight + 1.0
else:
weight_to_apply = weight
# Use torch.nn.functional.rms_norm for the normalization
hidden_states = F.rms_norm(
hidden_states, (hidden_states.shape[-1],), weight=weight_to_apply.to(torch.float32), eps=eps
)
hidden_states = hidden_states.to(input_dtype)
if residual is not None:
return hidden_states, new_residual
else:
return hidden_states
@pytest.mark.parametrize(
"M,N",
[
(1, 128),
(4, 256),
(16, 512),
(32, 1024),
(64, 2048),
(128, 4096),
(8, 8192),
],
)
@pytest.mark.parametrize("use_gemma", [False, True])
@pytest.mark.parametrize("static_persistent", [False, True])
@pytest.mark.parametrize("gather", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_cuda_tile_rms_norm(M, N, use_gemma, static_persistent, gather, dtype):
"""Test cuda_tile_rms_norm operator against reference implementation."""
eps = 1e-5
# Create input tensors
x = torch.randn(M, N, dtype=dtype, device="cuda")
weight = torch.randn(N, dtype=dtype, device="cuda")
# Clone for reference computation
x_ref = x.clone()
weight_ref = weight.clone()
# Compute reference
ref_output = reference_rms_norm(x_ref, weight_ref, eps, use_gemma)
# Compute with cuda_tile kernel
cuda_output = torch.ops.trtllm.cuda_tile_rms_norm(
x=x,
weight=weight,
eps=eps,
static_persistent=static_persistent,
gather=gather,
use_gemma=use_gemma,
)
# Compare results
# Use relatively loose tolerance due to different computation orders
rtol = 1e-2 if dtype == torch.float16 else 5e-2
atol = 1e-3 if dtype == torch.float16 else 5e-3
torch.testing.assert_close(
cuda_output,
ref_output,
rtol=rtol,
atol=atol,
msg=f"cuda_tile_rms_norm output mismatch for M={M}, N={N}, "
f"use_gemma={use_gemma}, static_persistent={static_persistent}, "
f"gather={gather}, dtype={dtype}",
)
@pytest.mark.parametrize(
"M,N",
[
(1, 128),
(4, 256),
(16, 512),
(32, 1024),
(64, 2048),
(128, 4096),
(8, 8192),
],
)
@pytest.mark.parametrize("use_gemma", [False, True])
@pytest.mark.parametrize("static_persistent", [False, True])
@pytest.mark.parametrize("gather", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_cuda_tile_rms_norm_fuse_residual(M, N, use_gemma, static_persistent, gather, dtype):
"""Test cuda_tile_rms_norm_fuse_residual_ operator against reference implementation."""
eps = 1e-5
# Create input tensors
x = torch.randn(M, N, dtype=dtype, device="cuda")
residual = torch.randn(M, N, dtype=dtype, device="cuda")
weight = torch.randn(N, dtype=dtype, device="cuda")
# Clone for reference computation
x_ref = x.clone()
residual_ref = residual.clone()
weight_ref = weight.clone()
# Compute reference
ref_output, ref_new_residual = reference_rms_norm(
x_ref, weight_ref, eps, use_gemma, residual_ref
)
# Ensure tensors are contiguous for in-place operation
x = x.contiguous()
residual = residual.contiguous()
# Compute with cuda_tile kernel (in-place operation)
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
x=x,
residual=residual,
weight=weight,
eps=eps,
static_persistent=static_persistent,
gather=gather,
use_gemma=use_gemma,
)
# After in-place operation:
# x contains the normalized output
# residual contains the un-normalized sum (new residual)
# Compare results
# Use relatively loose tolerance due to different computation orders
rtol = 1e-2 if dtype == torch.float16 else 5e-2
atol = 1e-3 if dtype == torch.float16 else 5e-3
torch.testing.assert_close(
x,
ref_output,
rtol=rtol,
atol=atol,
msg=f"cuda_tile_rms_norm_fuse_residual_ output mismatch for M={M}, N={N}, "
f"use_gemma={use_gemma}, static_persistent={static_persistent}, "
f"gather={gather}, dtype={dtype}",
)
torch.testing.assert_close(
residual,
ref_new_residual,
rtol=rtol,
atol=atol,
msg=f"cuda_tile_rms_norm_fuse_residual_ residual mismatch for M={M}, N={N}, "
f"use_gemma={use_gemma}, static_persistent={static_persistent}, "
f"gather={gather}, dtype={dtype}",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_cuda_tile_rms_norm_fuse_residual_inplace(dtype):
"""Test that fuse_residual operator truly modifies tensors in-place."""
eps = 1e-5
M, N = 16, 256
x = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
residual = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
weight = torch.randn(N, dtype=dtype, device="cuda")
# Store original data pointers
x_data_ptr = x.data_ptr()
residual_data_ptr = residual.data_ptr()
# Call in-place operator
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
x=x,
residual=residual,
weight=weight,
eps=eps,
static_persistent=True,
gather=True,
use_gemma=False,
)
# Verify that tensors were modified in-place (same memory location)
assert x.data_ptr() == x_data_ptr, "x tensor was not modified in-place"
assert residual.data_ptr() == residual_data_ptr, "residual tensor was not modified in-place"
def test_cuda_tile_rms_norm_fuse_residual_requires_contiguous():
"""Test that fuse_residual operator requires contiguous tensors."""
eps = 1e-5
M, N = 16, 256
dtype = torch.float16
# Create non-contiguous tensors
x = torch.randn(M, N * 2, dtype=dtype, device="cuda")[:, ::2]
residual = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
weight = torch.randn(N, dtype=dtype, device="cuda")
assert not x.is_contiguous(), "x should be non-contiguous for this test"
# Should raise assertion error for non-contiguous x
with pytest.raises(AssertionError, match="x must be contiguous"):
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
x=x,
residual=residual,
weight=weight,
eps=eps,
static_persistent=True,
gather=True,
use_gemma=False,
)
# Create non-contiguous residual
x = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
residual = torch.randn(M, N * 2, dtype=dtype, device="cuda")[:, ::2]
assert not residual.is_contiguous(), "residual should be non-contiguous for this test"
# Should raise assertion error for non-contiguous residual
with pytest.raises(AssertionError, match="residual must be contiguous"):
torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
x=x,
residual=residual,
weight=weight,
eps=eps,
static_persistent=True,
gather=True,
use_gemma=False,
)