[None][fix] Update RMSNorm custom op plumbing (#10843)

Signed-off-by: jintaop <jintaop@nvidia.com>
彭晋韬(jtao peng) 2026-01-22 21:03:22 +08:00 committed by GitHub
parent 1dc49b266e
commit 9beb971827
3 changed files with 63 additions and 148 deletions

View File

@@ -916,3 +916,23 @@ def _register_fake():
# This is a fake implementation for shape inference
# The actual operation modifies fused_q and q_pe in-place
return None
@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
def _(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m, n = input.shape
# normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
# output: [M, N] pre-norm output, same dtype as input
output = input.new_empty((m, n), dtype=input.dtype)
# sf_out: scale factors, swizzled layout
sf_vec_size = 16
sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
return normed_output_fp4, output, sf_out
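
For reference, this fake (meta) registration only has to produce outputs with the correct shapes and dtypes so that shape inference and torch.compile can trace the op without launching the CUDA kernel. A minimal sketch of the size arithmetic, assuming hypothetical dimensions M=32, N=4096:

# Hypothetical sizes used only to illustrate the fake implementation's arithmetic.
m, n = 32, 4096
sf_vec_size = 16                              # 16 FP4 values share one uint8 scale factor
packed_cols = n // 8                          # 8 FP4 values per int32 -> [32, 512] int32
sf_rows = ((m + 127) // 128) * 128            # rows padded to a multiple of 128 -> 128
sf_cols = ((n // sf_vec_size + 3) // 4) * 4   # 4096 // 16 = 256, padded to a multiple of 4 -> 256
sf_size = sf_rows * sf_cols                   # 32768 uint8 scale factors in the swizzled layout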

View File

@@ -1876,59 +1876,6 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
tensor.record_stream(stream)
def fused_add_rms_norm_quant(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.
Args:
input: [M, N] input tensor (fp16/bf16)
residual: [M, N] residual tensor (fp16/bf16)
gamma: [N] normalization weight (fp16/bf16)
sf_scale: [1] optional scale factor for FP4 quantization (float32)
use_rms_norm: if True use RMSNorm, else use LayerNorm
eps: epsilon for normalization
Returns:
normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
output: [M, N] pre-norm output (input + residual), same dtype as input
sf_out: scale factors for FP4 quantization (uint8), swizzled layout
Note:
This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
Hidden dimension N must be >= 2048 and <= 16384.
"""
return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma,
sf_scale, use_rms_norm,
eps)
@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
def _fused_add_rms_norm_quant_fake(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m, n = input.shape
# normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
# output: [M, N] pre-norm output, same dtype as input
output = input.new_empty((m, n), dtype=input.dtype)
# sf_out: scale factors, swizzled layout
sf_vec_size = 16
sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
return normed_output_fp4, output, sf_out
class Fp4GemmAllreduceRunner(TunableRunner):
runner_dict = dict()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(

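With the thin Python wrapper removed, call sites now invoke the registered custom op directly through torch.ops. A minimal usage sketch, assuming a CUDA build of TensorRT-LLM with the op registered and hypothetical tensor sizes (the hidden size must stay within the documented [2048, 16384] range):

import torch

m, n = 128, 4096  # hypothetical sizes
x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
res = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(n, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")

# Returns the packed FP4 normalized output, the pre-norm (x + res) tensor,
# and the swizzled uint8 scale factors, matching the docstring above.
normed_fp4_i32, prenorm_out, sf_out = torch.ops.trtllm.fused_add_rms_norm_quant(
    x, res, gamma, sf_scale, True, eps=1e-6)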
View File

@@ -20,8 +20,6 @@ from typing import Optional, Tuple, TypeAlias, Union, cast
import torch
from torch import nn
from tensorrt_llm.logger import logger
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..utils import Fp4QuantizedTensor
@@ -76,109 +74,59 @@ class RMSNorm(nn.Module):
_ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]:
return_residual = True
if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL:
return_residual = False
has_residual = residual is not self._ARGUMENT_NOT_SPECIFIED_SENTINEL
if not has_residual:
residual = None
if self.is_nvfp4 and residual is not None and not self.use_gemma:
if self.is_nvfp4 and has_residual and not self.use_gemma:
nvfp4_scale = getattr(self, "nvfp4_scale", None)
if nvfp4_scale is None:
raise ValueError(
f"layeridx={getattr(self, 'layer_idx', None)} RMSNorm NVFP4 output requested "
"but no `nvfp4_scale` is attached; ")
else:
def _can_use_fused_kernel() -> Tuple[bool, str]:
if not hidden_states.is_cuda or not residual.is_cuda:
return False, "inputs must be CUDA tensors"
if not self.weight.is_cuda:
return False, "gamma/weight must be a CUDA tensor"
if hidden_states.ndim < 2:
return False, "input must have rank >= 2"
if hidden_states.shape != residual.shape:
return False, f"input/residual shape mismatch: {tuple(hidden_states.shape)} vs {tuple(residual.shape)}"
n = int(hidden_states.shape[-1])
if self.weight.ndim != 1 or int(self.weight.numel()) != n:
return False, f"gamma/weight must be 1D with numel == hidden_size ({n}), got shape={tuple(self.weight.shape)}"
# Match the underlying C++ op: fp16/bf16 only (no fp8).
if hidden_states.dtype not in (torch.float16,
torch.bfloat16):
return False, f"unsupported dtype {hidden_states.dtype} (expected fp16/bf16)"
if n % 16 != 0:
return False, f"hidden size must be divisible by 16 (got {n})"
# Kernel constraints (see fusedAddRMSNormQuant.cpp).
if n < 2048 or n > 16384:
return False, f"hidden size must be in [2048, 16384] (got {n})"
# SM90+ only.
major, _minor = torch.cuda.get_device_capability(
hidden_states.device)
if major < 9:
return False, f"requires SM90+ GPU, got SM{major}{_minor}"
# Scale tensor constraints.
if (nvfp4_scale is not None
and ((not nvfp4_scale.is_cuda) or nvfp4_scale.dtype
!= torch.float32 or nvfp4_scale.numel() != 1)):
return False, f"nvfp4_scale must be a CUDA float32 tensor with numel==1 (got dtype={getattr(nvfp4_scale, 'dtype', None)}, device={getattr(nvfp4_scale, 'device', None)}, numel={getattr(nvfp4_scale, 'numel', lambda: None)()})"
return True, ""
orig_shape = tuple(hidden_states.shape)
n = int(orig_shape[-1])
hs_2d = hidden_states.reshape(-1, n).contiguous()
res_2d = residual.reshape(-1, n)
gamma = self.weight
ok, reason = _can_use_fused_kernel()
if not ok:
raise RuntimeError(
"RMSNorm NVFP4 fused path disabled due to unsupported inputs "
f"(falling back to unfused RMSNorm): {reason}")
else:
from ..custom_ops.torch_custom_ops import \
fused_add_rms_norm_quant
orig_shape = tuple(hidden_states.shape)
n = int(orig_shape[-1])
hs_2d = hidden_states.reshape(-1, n).contiguous()
res_2d = residual.reshape(-1, n)
gamma = self.weight
def _ensure_contiguous_with_dtype(t: torch.Tensor,
key: str):
if t.dtype != hs_2d.dtype:
logger.warning_once(
f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}.",
key=f"rmsnorm_nvfp4_cast_{key}",
)
t = t.to(dtype=hs_2d.dtype)
return t.contiguous()
res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
raise RuntimeError(
"RMSNorm NVFP4 fused path requires all tensors on the same device. "
f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
)
sf_scale = nvfp4_scale.contiguous(
) if nvfp4_scale is not None else None
normed_fp4_i32, residual_out_2d, sf_fused = fused_add_rms_norm_quant(
hs_2d,
res_2d,
gamma,
sf_scale,
True,
eps=self.variance_epsilon,
def _ensure_contiguous_with_dtype(t: torch.Tensor, key: str):
if t.dtype != hs_2d.dtype:
raise ValueError(
f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}."
)
normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
if len(orig_shape) != 2:
normed_fp4_u8 = normed_fp4_u8.reshape(
*orig_shape[:-1], n // 2)
residual_out = residual_out_2d.reshape(orig_shape)
else:
residual_out = residual_out_2d
return t.contiguous()
hidden_states_fused = Fp4QuantizedTensor(
normed_fp4_u8, sf_fused)
return (hidden_states_fused, residual_out
) if return_residual else hidden_states_fused
res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
raise RuntimeError(
"RMSNorm NVFP4 fused path requires all tensors on the same device. "
f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
)
sf_scale = nvfp4_scale.contiguous()
normed_fp4_i32, residual_out_2d, sf_fused = torch.ops.trtllm.fused_add_rms_norm_quant(
hs_2d,
res_2d,
gamma,
sf_scale,
True,
eps=self.variance_epsilon,
)
normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
if len(orig_shape) != 2:
normed_fp4_u8 = normed_fp4_u8.reshape(*orig_shape[:-1], n // 2)
residual_out = residual_out_2d.reshape(orig_shape)
else:
residual_out = residual_out_2d
hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused)
return (hidden_states_fused,
residual_out) if has_residual else hidden_states_fused
if IS_FLASHINFER_AVAILABLE:
from ..custom_ops import (flashinfer_fused_add_rmsnorm,
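
For context on what the fused call above computes: each int32 of the packed output holds eight FP4 values, so viewing it as uint8 yields two FP4 values per byte and n // 2 bytes per row, which is what the reshape restores. The normalization itself is the usual add + RMSNorm; below is a minimal unfused reference sketch (FP4 quantization and the swizzled scale-factor layout omitted), assuming use_rms_norm=True:

import torch

def add_rms_norm_reference(hidden_states, residual, gamma, eps=1e-6):
    # Pre-norm output, returned by the fused op as its second result.
    prenorm = hidden_states + residual
    # RMSNorm in fp32 for numerical stability, then cast back to the input dtype.
    x = prenorm.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(variance + eps)
    return (normed * gamma.to(torch.float32)).to(hidden_states.dtype), prenorm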
@@ -218,7 +166,7 @@ class RMSNorm(nn.Module):
hidden_states = (self.weight +
1) * hidden_states.to(input_dtype)
if return_residual:
if has_residual:
return hidden_states, cast(Optional[torch.Tensor], residual)
else:
return hidden_states
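
Usage-wise, the sentinel-based residual argument means the return arity follows the call: passing a residual yields a (hidden_states, residual) pair, while omitting it yields a single tensor. A hypothetical sketch of both call patterns (norm is an RMSNorm instance; the argument order is assumed, not shown in this diff):

out = norm(hidden_states)                          # no residual: single tensor (or Fp4QuantizedTensor)
out, new_residual = norm(hidden_states, residual)  # with residual: two-element tuple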