Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[None][feat] Integrate cuda.tile RMS norm kernels (#9725)
Signed-off-by: Rundong (David) Li <davidli@nvidia.com>
Co-authored-by: Jinman Xie <jinmanx@nvidia.com>
Co-authored-by: Alexey Bylinkin <abylinkin@nvidia.com>
Co-authored-by: Qiqi Xiao <qiqix@nvidia.com>
Co-authored-by: Biao Wang <biaow@nvidia.com>
Co-authored-by: Thomas Schmid <thschmid@nvidia.com>
This commit is contained in:
parent 13b0ab9c0e
commit f1b85fea4c
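For reference, the computation implemented by the new kernels is standard RMSNorm with an optional Gemma-style weight bias, read off from the kernel bodies below:

\[
\mathrm{rstd} = \frac{1}{\sqrt{\frac{1}{N}\sum_{i=1}^{N} x_i^2 + \varepsilon}},
\qquad
y_i = x_i \cdot \mathrm{rstd} \cdot w_i
\quad\text{(with } w_i + 1 \text{ when use\_gemma is set)}
\]

The fused-residual variants first overwrite x with x + residual, keep that sum as the new residual, and normalize the sum.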
@@ -82,3 +82,5 @@ mistral-common==1.8.6
 torchao>=0.14.1
 cuda-core
 llist
+cuda-tile>=1.0.1
+nvidia-cuda-tileiras>=13.1
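For a manual setup outside the requirements file, the two new runtime dependencies can be installed directly (version pins copied from the hunk above):

pip install "cuda-tile>=1.0.1" "nvidia-cuda-tileiras>=13.1"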
@@ -5,6 +5,8 @@ import torch
 from torch.fx import Node
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
+from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
+
 
 def get_symint_val(i: Union[torch.SymInt | int]):
     if isinstance(i, int):
@@ -93,6 +95,13 @@ def inplace_info():
         },
         torch.ops.trtllm.pp_send_tensors.default: {
             1: "tensors"
         }
     }
+    if IS_CUDA_TILE_AVAILABLE:
+        # cuda.tile availability depends on GPU capability, hence the runtime check.
+        inplace_map[
+            torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_.default] = {
+                1: "x",
+                2: "residual"
+            }
     return inplace_map
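For orientation: inplace_info maps each in-place custom op to a {argument position: argument name} dict for the arguments it mutates. A minimal sketch of a hypothetical consumer (the helper name is illustrative, not part of the diff):

inplace_map = inplace_info()

def mutated_args(node):
    """Yield (position, name) for arguments that node's op mutates in place."""
    # node is assumed to be a torch.fx.Node whose target may be a registered op.
    for pos, name in inplace_map.get(node.target, {}).items():
        yield pos, name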
tensorrt_llm/_torch/cuda_tile_kernels/__init__.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    from .rms_norm import rms_norm_kernel, rms_norm_kernel_gather, rms_norm_kernel_static_persistent
    from .rms_norm_fuse_residual import (
        rms_norm_fuse_residual_kernel,
        rms_norm_fuse_residual_kernel_gather,
        rms_norm_fuse_residual_kernel_static_persistent,
    )

    __all__ = [
        "rms_norm_kernel",
        "rms_norm_kernel_gather",
        "rms_norm_kernel_static_persistent",
        "rms_norm_fuse_residual_kernel",
        "rms_norm_fuse_residual_kernel_gather",
        "rms_norm_fuse_residual_kernel_static_persistent",
    ]
tensorrt_llm/_torch/cuda_tile_kernels/rms_norm.py (new file, 203 lines)
@@ -0,0 +1,203 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/kernels/rms_norm.py
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    import cuda.tile as ct

    @ct.kernel
    def rms_norm_kernel(
        x,
        w,
        out,
        Rstd,
        N: ct.Constant[int],
        eps: ct.Constant[float],
        TILE_SIZE: ct.Constant[int],
        use_gemma: ct.Constant[bool],
    ):
        """Standard RMSNorm kernel for non-static persistent mode with tiled loads"""
        row = ct.bid(0)
        _rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32)
        num_tiles = ct.cdiv(x.shape[1], TILE_SIZE)

        for j in range(0, num_tiles):
            xj = ct.load(
                x,
                index=(row, j),
                shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            xj = ct.astype(xj, ct.float32)
            _rms += xj * xj

        # Calculate RMS Norm
        rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps)
        ct.store(Rstd, index=(row,), tile=rms)

        for j in range(0, num_tiles):
            wj = ct.load(
                w,
                index=(j,),
                shape=(TILE_SIZE,),
                allow_tma=False,
                latency=1,
            )
            wj = ct.astype(wj, ct.float32)
            # Apply Gemma-style bias if enabled
            if use_gemma:
                wj = wj + 1.0
            xj = ct.load(
                x,
                index=(row, j),
                shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            xj = ct.astype(xj, ct.float32)
            yj = xj * rms * wj
            yj = ct.astype(yj, x.dtype)
            ct.store(
                out,
                index=(row, j),
                tile=yj,
                allow_tma=False,
                latency=1,
            )

    @ct.kernel
    def rms_norm_kernel_gather(
        x,
        w,
        out,
        Rstd,
        N: ct.Constant[int],
        eps: ct.Constant[float],
        TILE_SIZE: ct.Constant[int],
        use_gemma: ct.Constant[bool],
    ):
        """Standard RMSNorm kernel for non-static persistent mode with ptr loads"""
        row = ct.bid(0)
        _rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32)
        num_tiles = ct.cdiv(N, TILE_SIZE)
        offsets = ct.arange(TILE_SIZE, dtype=ct.int32)

        for j in range(0, num_tiles):
            offs = j * TILE_SIZE + offsets
            xj = ct.gather(x, (row, offs), latency=1)
            xj = ct.astype(xj, ct.float32)
            _rms += xj * xj

        # Calculate RMS Norm
        rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps)
        ct.scatter(Rstd, row, rms)

        for j in range(0, num_tiles):
            offs = j * TILE_SIZE + offsets
            wj = ct.gather(w, offs, latency=1)
            wj = ct.astype(wj, ct.float32)
            # Apply Gemma-style bias if enabled
            if use_gemma:
                wj = wj + 1.0
            xj = ct.gather(x, (row, offs), latency=1)
            xj = ct.astype(xj, ct.float32)
            yj = xj * rms * wj
            yj = ct.astype(yj, x.dtype)
            ct.scatter(out, (row, offs), yj, latency=1)

    @ct.kernel
    def rms_norm_kernel_static_persistent(
        X,  # Input tensor
        Y,  # Output tensor
        W,  # Weight tensor
        TILE_SIZE_M: ct.Constant[int],  # 4 rows per block
        TILE_SIZE_N: ct.Constant[int],  # columns per block
        eps: ct.Constant[float],  # Epsilon value
        use_gemma: ct.Constant[bool],  # Gemma-style weight bias
    ):
        """
        CuTile static persistent RMSNorm kernel that processes multiple blocks per program.
        Each program processes multiple blocks in a loop for better efficiency.
        """
        # Get program ID
        pid = ct.bid(0)

        # Infer tensor dimensions from input shape
        M = X.shape[0]  # Number of rows
        N = X.shape[1]  # Number of columns

        # Calculate upper bound - number of row blocks to process
        upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M

        # Load weight vector once (shared across all blocks processed by this program)
        w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,))
        w = ct.astype(w, ct.float32)
        # Apply Gemma-style bias if enabled
        if use_gemma:
            w = w + 1.0

        # Static persistent loop: each program processes multiple blocks
        num_tiles_x = ct.num_blocks(0)
        for current_block in range(pid, upper_bound, num_tiles_x):
            # Load input tile
            x = ct.load(
                X,
                index=(current_block, 0),
                shape=(TILE_SIZE_M, TILE_SIZE_N),
                latency=10,  # +2% perf from this hint
            )
            x = ct.astype(x, ct.float32)

            # Step 1: Compute x^2
            x_squared = ct.mul(x, x)

            # Step 2: Reduce sum along axis=1 (columns)
            x2_sum = ct.sum(x_squared, axis=1, keepdims=True)  # Shape: [TILE_SIZE_M, 1]

            # Step 3: Compute variance (divide by N)
            N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32)
            variance = ct.truediv(x2_sum, N_f32)

            # Step 4: Add epsilon and compute rsqrt
            eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32)
            variance_eps = ct.add(variance, eps_tensor)
            rsqrt_var = ct.rsqrt(variance_eps)

            # Step 5: Apply normalization
            x_normalized = ct.mul(x, rsqrt_var)

            # Step 6: Apply linear transformation
            # Broadcast weight to match input shape
            w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N))
            b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32)

            # Apply linear transformation: y = x_normalized * w + b
            y = ct.mul(x_normalized, w_broadcasted)
            y = ct.add(y, b_broadcasted)

            # Convert back to original dtype
            y = ct.astype(y, X.dtype)

            # Store result
            ct.store(
                Y,
                index=(current_block, 0),
                tile=y,
                allow_tma=False,  # +30% perf
                latency=3,  # +3% perf from this hint
            )
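For intuition, an eager PyTorch sketch (not part of the diff) of the quantities these kernels write to out and Rstd:

import torch

def rms_norm_reference(x, w, eps, use_gemma=False):
    # rstd matches what the kernels store into Rstd; y matches out.
    x_f32 = x.float()
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1) + eps)
    w_f32 = w.float() + 1.0 if use_gemma else w.float()
    y = (x_f32 * rstd.unsqueeze(-1) * w_f32).to(x.dtype)
    return y, rstd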
tensorrt_llm/_torch/cuda_tile_kernels/rms_norm_fuse_residual.py (new file, 258 lines)
@@ -0,0 +1,258 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from rms_norm.py with residual fusion support
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    import cuda.tile as ct

    @ct.kernel
    def rms_norm_fuse_residual_kernel(
        x,
        residual,
        w,
        Rstd,
        N: ct.Constant[int],
        eps: ct.Constant[float],
        TILE_SIZE: ct.Constant[int],
        use_gemma: ct.Constant[bool],
    ):
        """RMSNorm kernel with residual fusion for non-static persistent mode with tiled loads"""
        row = ct.bid(0)
        _rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32)
        num_tiles = ct.cdiv(x.shape[1], TILE_SIZE)

        # First pass: compute RMS with fused residual addition and store sum to residual
        for j in range(0, num_tiles):
            xj = ct.load(
                x,
                index=(row, j),
                shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            residual_j = ct.load(
                residual,
                index=(row, j),
                shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            # Fuse residual: convert to float32, add, then use for RMS computation
            xj = ct.astype(xj, ct.float32)
            residual_j = ct.astype(residual_j, ct.float32)
            xj = xj + residual_j
            _rms += xj * xj

            # Store the sum (new residual) back to residual tensor
            xj_stored = ct.astype(xj, residual.dtype)
            ct.store(
                residual,
                index=(row, j),
                tile=xj_stored,
                allow_tma=False,
                latency=1,
            )

        # Calculate RMS Norm
        rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps)
        ct.store(Rstd, index=(row,), tile=rms)

        # Second pass: load from residual (which now contains the sum), apply normalization, store to x
        for j in range(0, num_tiles):
            wj = ct.load(
                w,
                index=(j,),
                shape=(TILE_SIZE,),
                allow_tma=False,
                latency=1,
            )
            wj = ct.astype(wj, ct.float32)
            # Apply Gemma-style bias if enabled
            if use_gemma:
                wj = wj + 1.0
            residual_j = ct.load(
                residual,
                index=(row, j),
                shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            # Load from residual (which now contains x + residual sum)
            residual_j = ct.astype(residual_j, ct.float32)
            yj = residual_j * rms * wj
            yj = ct.astype(yj, x.dtype)
            ct.store(
                x,
                index=(row, j),
                tile=yj,
                allow_tma=False,
                latency=1,
            )

    @ct.kernel
    def rms_norm_fuse_residual_kernel_gather(
        x,
        residual,
        w,
        Rstd,
        N: ct.Constant[int],
        eps: ct.Constant[float],
        TILE_SIZE: ct.Constant[int],
        use_gemma: ct.Constant[bool],
    ):
        """RMSNorm kernel with residual fusion for non-static persistent mode with ptr loads"""
        row = ct.bid(0)
        _rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32)
        num_tiles = ct.cdiv(N, TILE_SIZE)
        offsets = ct.arange(TILE_SIZE, dtype=ct.int32)

        # First pass: compute RMS with fused residual addition and store sum to residual
        for j in range(0, num_tiles):
            offs = j * TILE_SIZE + offsets
            xj = ct.gather(x, (row, offs), latency=1)
            residual_j = ct.gather(residual, (row, offs), latency=1)
            # Fuse residual: convert to float32, add, then use for RMS computation
            xj = ct.astype(xj, ct.float32)
            residual_j = ct.astype(residual_j, ct.float32)
            xj = xj + residual_j
            _rms += xj * xj

            # Store the sum (new residual) back to residual tensor
            xj_stored = ct.astype(xj, residual.dtype)
            ct.scatter(residual, (row, offs), xj_stored, latency=1)

        # Calculate RMS Norm
        rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps)
        ct.scatter(Rstd, row, rms)

        # Second pass: load from residual (which now contains the sum), apply normalization, store to x
        for j in range(0, num_tiles):
            offs = j * TILE_SIZE + offsets
            wj = ct.gather(w, offs, latency=1)
            wj = ct.astype(wj, ct.float32)
            # Apply Gemma-style bias if enabled
            if use_gemma:
                wj = wj + 1.0
            residual_j = ct.gather(residual, (row, offs), latency=1)
            # Load from residual (which now contains x + residual sum)
            residual_j = ct.astype(residual_j, ct.float32)
            yj = residual_j * rms * wj
            yj = ct.astype(yj, x.dtype)
            ct.scatter(x, (row, offs), yj, latency=1)

    @ct.kernel
    def rms_norm_fuse_residual_kernel_static_persistent(
        X,  # Input tensor
        Residual,  # Residual tensor
        W,  # Weight tensor
        TILE_SIZE_M: ct.Constant[int],  # 4 rows per block
        TILE_SIZE_N: ct.Constant[int],  # columns per block
        eps: ct.Constant[float],  # Epsilon value
        use_gemma: ct.Constant[bool],  # Gemma-style weight bias
    ):
        """
        CuTile static persistent RMSNorm kernel with residual fusion that processes multiple blocks per program.
        Each program processes multiple blocks in a loop for better efficiency.
        """
        # Get program ID
        pid = ct.bid(0)

        # Infer tensor dimensions from input shape
        M = X.shape[0]  # Number of rows
        N = X.shape[1]  # Number of columns

        # Calculate upper bound - number of row blocks to process
        upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M

        # Load weight vector once (shared across all blocks processed by this program)
        w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,))
        w = ct.astype(w, ct.float32)
        # Apply Gemma-style bias if enabled
        if use_gemma:
            w = w + 1.0

        # Static persistent loop: each program processes multiple blocks
        num_tiles_x = ct.num_blocks(0)
        for current_block in range(pid, upper_bound, num_tiles_x):
            # Load input tile
            x = ct.load(
                X,
                index=(current_block, 0),
                shape=(TILE_SIZE_M, TILE_SIZE_N),
                latency=10,  # +2% perf from this hint
            )
            # Load residual tile
            residual = ct.load(
                Residual,
                index=(current_block, 0),
                shape=(TILE_SIZE_M, TILE_SIZE_N),
                latency=10,
            )

            # Fuse residual: convert to float32 and add
            x = ct.astype(x, ct.float32)
            residual = ct.astype(residual, ct.float32)
            x = ct.add(x, residual)

            # Store the sum (new residual) back to Residual tensor
            x_stored = ct.astype(x, Residual.dtype)
            ct.store(
                Residual,
                index=(current_block, 0),
                tile=x_stored,
                allow_tma=False,
                latency=3,
            )

            # Step 1: Compute x^2
            x_squared = ct.mul(x, x)

            # Step 2: Reduce sum along axis=1 (columns)
            x2_sum = ct.sum(x_squared, axis=1, keepdims=True)  # Shape: [TILE_SIZE_M, 1]

            # Step 3: Compute variance (divide by N)
            N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32)
            variance = ct.truediv(x2_sum, N_f32)

            # Step 4: Add epsilon and compute rsqrt
            eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32)
            variance_eps = ct.add(variance, eps_tensor)
            rsqrt_var = ct.rsqrt(variance_eps)

            # Step 5: Apply normalization
            x_normalized = ct.mul(x, rsqrt_var)

            # Step 6: Apply linear transformation
            # Broadcast weight to match input shape
            w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N))
            b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32)

            # Apply linear transformation: y = x_normalized * w + b
            y = ct.mul(x_normalized, w_broadcasted)
            y = ct.add(y, b_broadcasted)

            # Convert back to original dtype and store to X (new hidden_states)
            y = ct.astype(y, X.dtype)

            # Store result to X
            ct.store(
                X,
                index=(current_block, 0),
                tile=y,
                allow_tma=False,  # +30% perf
                latency=3,  # +3% perf from this hint
            )
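The in-place contract of the fused kernels, restated as an eager PyTorch sketch (not part of the diff): residual ends up holding the pre-norm sum, and x the normalized output.

import torch

def rms_norm_fuse_residual_reference(x, residual, w, eps, use_gemma=False):
    s = x.float() + residual.float()
    residual.copy_(s.to(residual.dtype))  # new residual = x + old residual
    rstd = torch.rsqrt(s.pow(2).mean(dim=-1, keepdim=True) + eps)
    w_f32 = w.float() + 1.0 if use_gemma else w.float()
    x.copy_((s * rstd * w_f32).to(x.dtype))  # x = RMSNorm(x + old residual)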
tensorrt_llm/_torch/cuda_tile_utils.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import importlib.metadata
import os
import platform
import shutil

import torch

from ..logger import logger

IS_CUDA_TILE_AVAILABLE = False


@functools.lru_cache()
def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n


def ceil_div(a, b):
    return (a + b - 1) // b


if platform.system() != "Windows":
    try:
        import cuda.tile  # noqa: F401
    except ImportError:
        logger.warning("cuda-tile package not found, TileIR kernels will not be available")
    else:
        if (cc := torch.cuda.get_device_properties()) and (cc.major, cc.minor) < (10, 0):
            logger.warning(
                f"TileIR requires compute capability 10.0 or higher, but the current device has "
                f"{cc.major}.{cc.minor}. TileIR kernels will not be available"
            )
        elif shutil.which("tileiras") is not None:
            IS_CUDA_TILE_AVAILABLE = True
        # For systems without tileiras installed, try to locate it from the nvidia-cuda-tileiras package.
        elif tileiras_files := importlib.metadata.files("nvidia-cuda-tileiras"):
            for pkg_file in tileiras_files:
                if pkg_file.name == "tileiras":
                    tileiras_dir = pkg_file.locate().parent
                    os.environ["PATH"] = f"{os.environ['PATH']}:{tileiras_dir}"
                    break
            assert shutil.which("tileiras") is not None
            IS_CUDA_TILE_AVAILABLE = True
        else:
            logger.warning("tileiras compiler not found, TileIR kernels will not be available")
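A quick sanity check of the two arithmetic helpers above (a sketch; the values follow directly from the definitions):

from tensorrt_llm._torch.cuda_tile_utils import ceil_div, next_power_of_2

assert next_power_of_2(1) == 1
assert next_power_of_2(3) == 4
assert next_power_of_2(4096) == 4096
assert next_power_of_2(5000) == 8192
assert ceil_div(10, 4) == 3  # 10 rows in tiles of 4 -> 3 tiles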
@@ -1,3 +1,4 @@
+from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
 from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..modules.attention import attn_custom_op_inplace, mla_custom_op_inplace
@@ -38,3 +39,11 @@ if IS_CUTLASS_DSL_AVAILABLE:
     __all__ += [
         'cute_dsl_nvfp4_gemm_blackwell',
     ]
+
+if IS_CUDA_TILE_AVAILABLE:
+    from .cuda_tile_custom_ops import (cuda_tile_rms_norm,
+                                       cuda_tile_rms_norm_fuse_residual_)
+    __all__ += [
+        'cuda_tile_rms_norm',
+        'cuda_tile_rms_norm_fuse_residual_',
+    ]
tensorrt_llm/_torch/custom_ops/cuda_tile_custom_ops.py (new file, 188 lines)
@@ -0,0 +1,188 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/bench_rms_norm.py

import torch

from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    import cuda.tile as ct

    from ..cuda_tile_kernels import (
        rms_norm_fuse_residual_kernel,
        rms_norm_fuse_residual_kernel_gather,
        rms_norm_fuse_residual_kernel_static_persistent,
        rms_norm_kernel,
        rms_norm_kernel_gather,
        rms_norm_kernel_static_persistent,
    )
    from ..cuda_tile_utils import ceil_div, next_power_of_2

    @torch.library.custom_op("trtllm::cuda_tile_rms_norm", mutates_args=())
    def cuda_tile_rms_norm(
        x: torch.Tensor,
        weight: torch.Tensor,
        eps: float,
        static_persistent: bool,
        gather: bool,
        use_gemma: bool,
    ) -> torch.Tensor:
        x = x.contiguous()
        weight = weight.contiguous()

        # Allocate output tensor
        y = torch.empty_like(x)
        M, N = x.shape

        if static_persistent:
            NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
            TILE_SIZE_M = 4  # Default value, could be made configurable
            TILE_SIZE_N = next_power_of_2(N)

            # Other tile sizes are more optimal when the N dimension is very large or very small
            if TILE_SIZE_N <= 1024:
                TILE_SIZE_M = 16
            elif TILE_SIZE_N >= 16384:
                TILE_SIZE_M = 2

            grid_size = min(
                NUM_SMS,
                ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N),
            )
            grid = (grid_size,)
            ct.launch(
                torch.cuda.current_stream(),
                grid,
                rms_norm_kernel_static_persistent,
                (
                    x,
                    y,
                    weight,
                    TILE_SIZE_M,
                    TILE_SIZE_N,
                    eps,
                    use_gemma,
                ),
            )
        else:
            # Standard RMSNorm kernel
            rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
            MAX_FUSED_SIZE = 2048 // x.element_size()
            TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N))
            grid = (M,)
            kernel = rms_norm_kernel_gather if gather else rms_norm_kernel
            ct.launch(
                torch.cuda.current_stream(),
                grid,
                kernel,
                (
                    x,
                    weight,
                    y,
                    rstd,
                    N,
                    eps,
                    TILE_SIZE,
                    use_gemma,
                ),
            )
        return y.view(*x.shape)

    @cuda_tile_rms_norm.register_fake
    def _(
        x: torch.Tensor,
        weight: torch.Tensor,
        eps: float,
        static_persistent: bool,
        gather: bool,
        use_gemma: bool,
    ) -> torch.Tensor:
        return torch.empty_like(x.contiguous())

    @torch.library.custom_op(
        "trtllm::cuda_tile_rms_norm_fuse_residual_",
        mutates_args=("x", "residual"),
    )
    def cuda_tile_rms_norm_fuse_residual_(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        eps: float,
        static_persistent: bool,
        gather: bool,
        use_gemma: bool,
    ) -> None:
        assert x.is_contiguous(), "x must be contiguous for in-place operation"
        assert residual.is_contiguous(), "residual must be contiguous for in-place operation"
        weight = weight.contiguous()

        M, N = x.shape

        if static_persistent:
            NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
            TILE_SIZE_M = 4  # Default value, could be made configurable
            TILE_SIZE_N = next_power_of_2(N)

            # Other tile sizes are more optimal when the N dimension is very large or very small
            if TILE_SIZE_N <= 1024:
                TILE_SIZE_M = 16
            elif TILE_SIZE_N >= 16384:
                TILE_SIZE_M = 2

            grid_size = min(
                NUM_SMS,
                ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N),
            )
            grid = (grid_size,)
            ct.launch(
                torch.cuda.current_stream(),
                grid,
                rms_norm_fuse_residual_kernel_static_persistent,
                (
                    x,
                    residual,
                    weight,
                    TILE_SIZE_M,
                    TILE_SIZE_N,
                    eps,
                    use_gemma,
                ),
            )
        else:
            # Standard RMSNorm kernel with residual fusion
            rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
            MAX_FUSED_SIZE = 2048 // x.element_size()
            TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N))
            grid = (M,)
            kernel = (
                rms_norm_fuse_residual_kernel_gather if gather else rms_norm_fuse_residual_kernel
            )
            ct.launch(
                torch.cuda.current_stream(),
                grid,
                kernel,
                (
                    x,
                    residual,
                    weight,
                    rstd,
                    N,
                    eps,
                    TILE_SIZE,
                    use_gemma,
                ),
            )
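To make the static-persistent launch heuristic concrete, a worked example in plain Python (the shape and SM count are assumed sample values, not from the diff):

from tensorrt_llm._torch.cuda_tile_utils import ceil_div, next_power_of_2

M, N, NUM_SMS = 8192, 4096, 148        # hypothetical shape and GPU
TILE_SIZE_N = next_power_of_2(N)       # 4096: neither <= 1024 nor >= 16384
TILE_SIZE_M = 4                        # so the default row-tile size is kept
grid_size = min(NUM_SMS, ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N))
assert grid_size == 148                # 2048 row blocks shared by 148 persistent programs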
@@ -20,6 +20,7 @@ from typing import Optional, Tuple, TypeAlias, Union, cast
 import torch
 from torch import nn
 
+from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..utils import Fp4QuantizedTensor
 
@@ -39,6 +40,7 @@ class RMSNorm(nn.Module):
         has_weights: bool = True,
         use_gemma: bool = False,
         quantize_type: Optional[str] = None,
+        use_cuda_tile: bool = False,
     ):
         super().__init__()
 
@@ -49,6 +51,10 @@ class RMSNorm(nn.Module):
             raise NotImplementedError(
                 f"Quantize type {quantize_type} not implemented in RMSNorm")
         self.is_nvfp4 = quantize_type == "nvfp4"
+        if use_cuda_tile and not IS_CUDA_TILE_AVAILABLE:
+            raise ValueError(
+                "cuda.tile is not available, please install the cuda-tile pypi package"
+            )
 
         if has_weights:
             if not use_gemma:
@@ -65,6 +71,7 @@ class RMSNorm(nn.Module):
                              persistent=False)
         self.variance_epsilon = eps
         self.use_gemma = use_gemma
+        self.use_cuda_tile = use_cuda_tile
 
     def forward(
         self,
@@ -127,8 +134,30 @@ class RMSNorm(nn.Module):
             hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused)
             return (hidden_states_fused,
                     residual_out) if has_residual else hidden_states_fused
 
-        if IS_FLASHINFER_AVAILABLE:
+        elif self.use_cuda_tile:
+            if isinstance(residual, torch.Tensor):
+                # Use fused residual kernel
+                hidden_states = hidden_states.contiguous()
+                residual = residual.contiguous()
+                torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
+                    x=hidden_states,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                    static_persistent=True,
+                    gather=True,
+                    use_gemma=self.use_gemma,
+                )
+            else:
+                hidden_states = torch.ops.trtllm.cuda_tile_rms_norm(
+                    x=hidden_states,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                    static_persistent=True,
+                    gather=True,
+                    use_gemma=self.use_gemma,
+                )
+        elif IS_FLASHINFER_AVAILABLE:
             from ..custom_ops import (flashinfer_fused_add_rmsnorm,
                                       flashinfer_gemma_fused_add_rmsnorm,
                                       flashinfer_gemma_rmsnorm,
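With this wiring, opting into the cuda.tile path is a constructor flag. A minimal usage sketch (the module path and the hidden_size/dtype argument names are assumed from the usual RMSNorm signature, not shown in this hunk):

import torch
from tensorrt_llm._torch.modules.rms_norm import RMSNorm  # path assumed

norm = RMSNorm(hidden_size=4096, eps=1e-5, dtype=torch.bfloat16,
               use_cuda_tile=True)  # raises ValueError if cuda.tile is unavailable
h = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
r = torch.randn_like(h)
out, new_residual = norm(h, r)  # fused path: h and r are updated in place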
tests/unittest/_torch/thop/parallel/test_cuda_tile_custom_ops.py (new file, 309 lines)
@@ -0,0 +1,309 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
import torch.nn.functional as F

import tensorrt_llm  # noqa: F401
from tensorrt_llm._torch.cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

# Skip all tests if CUDA tile is not available
pytestmark = pytest.mark.skipif(not IS_CUDA_TILE_AVAILABLE, reason="CUDA tile is not available")


@pytest.fixture(autouse=True)
def prepare_testcase_environment(tmp_path):
    """Set random seed and enable deterministic mode before each test."""
    # Enable deterministic execution.
    prev_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
    prev_deterministic_mode = torch.are_deterministic_algorithms_enabled()
    torch.manual_seed(19260817)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    # Enable cuTile debug dump.
    os.environ["CUDA_TILE_DUMP_BYTECODE"] = str(tmp_path / "bytecode")
    os.environ["CUDA_TILE_DUMP_TILEIR"] = str(tmp_path / "tileir")

    yield

    # Rewind to previous states.
    if prev_cublas_workspace_config is not None:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = prev_cublas_workspace_config
    else:
        del os.environ["CUBLAS_WORKSPACE_CONFIG"]
    torch.use_deterministic_algorithms(prev_deterministic_mode)
    del os.environ["CUDA_TILE_DUMP_BYTECODE"]
    del os.environ["CUDA_TILE_DUMP_TILEIR"]


def reference_rms_norm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    use_gemma: bool,
    residual: torch.Tensor | None = None,
):
    """
    Reference RMSNorm implementation using PyTorch operations.

    Args:
        hidden_states: Input tensor
        weight: Weight tensor
        eps: Epsilon for numerical stability
        use_gemma: Whether to use Gemma-style weight bias (weight + 1)
        residual: Optional residual tensor to add before normalization

    Returns:
        Tuple of (normalized output, new residual) if residual is provided,
        otherwise just normalized output
    """
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    new_residual = None
    if residual is not None:
        hidden_states = hidden_states + residual.to(torch.float32)
        new_residual = hidden_states.to(input_dtype)

    # Prepare weight with Gemma-style bias if needed
    if use_gemma:
        weight_to_apply = weight + 1.0
    else:
        weight_to_apply = weight

    # Use torch.nn.functional.rms_norm for the normalization
    hidden_states = F.rms_norm(
        hidden_states, (hidden_states.shape[-1],), weight=weight_to_apply.to(torch.float32), eps=eps
    )
    hidden_states = hidden_states.to(input_dtype)

    if residual is not None:
        return hidden_states, new_residual
    else:
        return hidden_states


@pytest.mark.parametrize(
    "M,N",
    [
        (1, 128),
        (4, 256),
        (16, 512),
        (32, 1024),
        (64, 2048),
        (128, 4096),
        (8, 8192),
    ],
)
@pytest.mark.parametrize("use_gemma", [False, True])
@pytest.mark.parametrize("static_persistent", [False, True])
@pytest.mark.parametrize("gather", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_cuda_tile_rms_norm(M, N, use_gemma, static_persistent, gather, dtype):
    """Test cuda_tile_rms_norm operator against reference implementation."""
    eps = 1e-5

    # Create input tensors
    x = torch.randn(M, N, dtype=dtype, device="cuda")
    weight = torch.randn(N, dtype=dtype, device="cuda")

    # Clone for reference computation
    x_ref = x.clone()
    weight_ref = weight.clone()

    # Compute reference
    ref_output = reference_rms_norm(x_ref, weight_ref, eps, use_gemma)

    # Compute with cuda_tile kernel
    cuda_output = torch.ops.trtllm.cuda_tile_rms_norm(
        x=x,
        weight=weight,
        eps=eps,
        static_persistent=static_persistent,
        gather=gather,
        use_gemma=use_gemma,
    )

    # Compare results
    # Use relatively loose tolerance due to different computation orders
    rtol = 1e-2 if dtype == torch.float16 else 5e-2
    atol = 1e-3 if dtype == torch.float16 else 5e-3

    torch.testing.assert_close(
        cuda_output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"cuda_tile_rms_norm output mismatch for M={M}, N={N}, "
        f"use_gemma={use_gemma}, static_persistent={static_persistent}, "
        f"gather={gather}, dtype={dtype}",
    )


@pytest.mark.parametrize(
    "M,N",
    [
        (1, 128),
        (4, 256),
        (16, 512),
        (32, 1024),
        (64, 2048),
        (128, 4096),
        (8, 8192),
    ],
)
@pytest.mark.parametrize("use_gemma", [False, True])
@pytest.mark.parametrize("static_persistent", [False, True])
@pytest.mark.parametrize("gather", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_cuda_tile_rms_norm_fuse_residual(M, N, use_gemma, static_persistent, gather, dtype):
    """Test cuda_tile_rms_norm_fuse_residual_ operator against reference implementation."""
    eps = 1e-5

    # Create input tensors
    x = torch.randn(M, N, dtype=dtype, device="cuda")
    residual = torch.randn(M, N, dtype=dtype, device="cuda")
    weight = torch.randn(N, dtype=dtype, device="cuda")

    # Clone for reference computation
    x_ref = x.clone()
    residual_ref = residual.clone()
    weight_ref = weight.clone()

    # Compute reference
    ref_output, ref_new_residual = reference_rms_norm(
        x_ref, weight_ref, eps, use_gemma, residual_ref
    )

    # Ensure tensors are contiguous for in-place operation
    x = x.contiguous()
    residual = residual.contiguous()

    # Compute with cuda_tile kernel (in-place operation)
    torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
        x=x,
        residual=residual,
        weight=weight,
        eps=eps,
        static_persistent=static_persistent,
        gather=gather,
        use_gemma=use_gemma,
    )

    # After in-place operation:
    # x contains the normalized output
    # residual contains the un-normalized sum (new residual)

    # Compare results
    # Use relatively loose tolerance due to different computation orders
    rtol = 1e-2 if dtype == torch.float16 else 5e-2
    atol = 1e-3 if dtype == torch.float16 else 5e-3

    torch.testing.assert_close(
        x,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"cuda_tile_rms_norm_fuse_residual_ output mismatch for M={M}, N={N}, "
        f"use_gemma={use_gemma}, static_persistent={static_persistent}, "
        f"gather={gather}, dtype={dtype}",
    )

    torch.testing.assert_close(
        residual,
        ref_new_residual,
        rtol=rtol,
        atol=atol,
        msg=f"cuda_tile_rms_norm_fuse_residual_ residual mismatch for M={M}, N={N}, "
        f"use_gemma={use_gemma}, static_persistent={static_persistent}, "
        f"gather={gather}, dtype={dtype}",
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_cuda_tile_rms_norm_fuse_residual_inplace(dtype):
    """Test that fuse_residual operator truly modifies tensors in-place."""
    eps = 1e-5
    M, N = 16, 256

    x = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
    residual = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
    weight = torch.randn(N, dtype=dtype, device="cuda")

    # Store original data pointers
    x_data_ptr = x.data_ptr()
    residual_data_ptr = residual.data_ptr()

    # Call in-place operator
    torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
        x=x,
        residual=residual,
        weight=weight,
        eps=eps,
        static_persistent=True,
        gather=True,
        use_gemma=False,
    )

    # Verify that tensors were modified in-place (same memory location)
    assert x.data_ptr() == x_data_ptr, "x tensor was not modified in-place"
    assert residual.data_ptr() == residual_data_ptr, "residual tensor was not modified in-place"


def test_cuda_tile_rms_norm_fuse_residual_requires_contiguous():
    """Test that fuse_residual operator requires contiguous tensors."""
    eps = 1e-5
    M, N = 16, 256
    dtype = torch.float16

    # Create non-contiguous tensors
    x = torch.randn(M, N * 2, dtype=dtype, device="cuda")[:, ::2]
    residual = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
    weight = torch.randn(N, dtype=dtype, device="cuda")

    assert not x.is_contiguous(), "x should be non-contiguous for this test"

    # Should raise assertion error for non-contiguous x
    with pytest.raises(AssertionError, match="x must be contiguous"):
        torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
            x=x,
            residual=residual,
            weight=weight,
            eps=eps,
            static_persistent=True,
            gather=True,
            use_gemma=False,
        )

    # Create non-contiguous residual
    x = torch.randn(M, N, dtype=dtype, device="cuda").contiguous()
    residual = torch.randn(M, N * 2, dtype=dtype, device="cuda")[:, ::2]

    assert not residual.is_contiguous(), "residual should be non-contiguous for this test"

    # Should raise assertion error for non-contiguous residual
    with pytest.raises(AssertionError, match="residual must be contiguous"):
        torch.ops.trtllm.cuda_tile_rms_norm_fuse_residual_(
            x=x,
            residual=residual,
            weight=weight,
            eps=eps,
            static_persistent=True,
            gather=True,
            use_gemma=False,
        )
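Assuming a compute-capability 10.0+ GPU with the cuda-tile packages installed, the suite runs directly with pytest: pytest tests/unittest/_torch/thop/parallel/test_cuda_tile_custom_ops.py -v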