diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
index 4fc53b676b..73ee7ee3d6 100644
--- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -916,3 +916,23 @@ def _register_fake():
         # This is a fake implementation for shape inference
         # The actual operation modifies fused_q and q_pe in-place
         return None
+
+    @torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
+    def _(
+        input: torch.Tensor,
+        residual: torch.Tensor,
+        gamma: torch.Tensor,
+        sf_scale: Optional[torch.Tensor],
+        use_rms_norm: bool = True,
+        eps: float = 1e-5,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        m, n = input.shape
+        # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
+        normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
+        # output: [M, N] pre-norm output, same dtype as input
+        output = input.new_empty((m, n), dtype=input.dtype)
+        # sf_out: scale factors, swizzled layout
+        sf_vec_size = 16
+        sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
+        sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
+        return normed_output_fp4, output, sf_out
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index c176fe70b3..5b683637c6 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -1876,59 +1876,6 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     tensor.record_stream(stream)
 
 
-def fused_add_rms_norm_quant(
-    input: torch.Tensor,
-    residual: torch.Tensor,
-    gamma: torch.Tensor,
-    sf_scale: Optional[torch.Tensor],
-    use_rms_norm: bool = True,
-    eps: float = 1e-6,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.
-
-    Args:
-        input: [M, N] input tensor (fp16/bf16)
-        residual: [M, N] residual tensor (fp16/bf16)
-        gamma: [N] normalization weight (fp16/bf16)
-        sf_scale: [1] optional scale factor for FP4 quantization (float32)
-        use_rms_norm: if True use RMSNorm, else use LayerNorm
-        eps: epsilon for normalization
-
-    Returns:
-        normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
-        output: [M, N] pre-norm output (input + residual), same dtype as input
-        sf_out: scale factors for FP4 quantization (uint8), swizzled layout
-
-    Note:
-        This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
-        Hidden dimension N must be >= 2048 and <= 16384.
- """ - return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma, - sf_scale, use_rms_norm, - eps) - - -@torch.library.register_fake("trtllm::fused_add_rms_norm_quant") -def _fused_add_rms_norm_quant_fake( - input: torch.Tensor, - residual: torch.Tensor, - gamma: torch.Tensor, - sf_scale: Optional[torch.Tensor], - use_rms_norm: bool = True, - eps: float = 1e-5, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - m, n = input.shape - # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32) - normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32) - # output: [M, N] pre-norm output, same dtype as input - output = input.new_empty((m, n), dtype=input.dtype) - # sf_out: scale factors, swizzled layout - sf_vec_size = 16 - sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4 - sf_out = input.new_empty((sf_size, ), dtype=torch.uint8) - return normed_output_fp4, output, sf_out - - class Fp4GemmAllreduceRunner(TunableRunner): runner_dict = dict() tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec( diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index a61f91e0cf..a3e6bcde96 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -20,8 +20,6 @@ from typing import Optional, Tuple, TypeAlias, Union, cast import torch from torch import nn -from tensorrt_llm.logger import logger - from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..utils import Fp4QuantizedTensor @@ -76,109 +74,59 @@ class RMSNorm(nn.Module): _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL, ) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[ torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]: - return_residual = True - if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL: - return_residual = False + has_residual = residual is not self._ARGUMENT_NOT_SPECIFIED_SENTINEL + if not has_residual: residual = None - if self.is_nvfp4 and residual is not None and not self.use_gemma: + if self.is_nvfp4 and has_residual and not self.use_gemma: nvfp4_scale = getattr(self, "nvfp4_scale", None) if nvfp4_scale is None: raise ValueError( f"layeridx={getattr(self, 'layer_idx', None)} RMSNorm NVFP4 output requested " "but no `nvfp4_scale` is attached; ") - else: - def _can_use_fused_kernel() -> Tuple[bool, str]: - if not hidden_states.is_cuda or not residual.is_cuda: - return False, "inputs must be CUDA tensors" - if not self.weight.is_cuda: - return False, "gamma/weight must be a CUDA tensor" - if hidden_states.ndim < 2: - return False, "input must have rank >= 2" - if hidden_states.shape != residual.shape: - return False, f"input/residual shape mismatch: {tuple(hidden_states.shape)} vs {tuple(residual.shape)}" - n = int(hidden_states.shape[-1]) - if self.weight.ndim != 1 or int(self.weight.numel()) != n: - return False, f"gamma/weight must be 1D with numel == hidden_size ({n}), got shape={tuple(self.weight.shape)}" - # Match the underlying C++ op: fp16/bf16 only (no fp8). - if hidden_states.dtype not in (torch.float16, - torch.bfloat16): - return False, f"unsupported dtype {hidden_states.dtype} (expected fp16/bf16)" - if n % 16 != 0: - return False, f"hidden size must be divisible by 16 (got {n})" - # Kernel constraints (see fusedAddRMSNormQuant.cpp). - if n < 2048 or n > 16384: - return False, f"hidden size must be in [2048, 16384] (got {n})" - # SM90+ only. 
-                    major, _minor = torch.cuda.get_device_capability(
-                        hidden_states.device)
-                    if major < 9:
-                        return False, f"requires SM90+ GPU, got SM{major}{_minor}"
-                    # Scale tensor constraints.
-                    if (nvfp4_scale is not None
-                            and ((not nvfp4_scale.is_cuda) or nvfp4_scale.dtype
-                                 != torch.float32 or nvfp4_scale.numel() != 1)):
-                        return False, f"nvfp4_scale must be a CUDA float32 tensor with numel==1 (got dtype={getattr(nvfp4_scale, 'dtype', None)}, device={getattr(nvfp4_scale, 'device', None)}, numel={getattr(nvfp4_scale, 'numel', lambda: None)()})"
-                    return True, ""
+            orig_shape = tuple(hidden_states.shape)
+            n = int(orig_shape[-1])
+            hs_2d = hidden_states.reshape(-1, n).contiguous()
+            res_2d = residual.reshape(-1, n)
+            gamma = self.weight
-                ok, reason = _can_use_fused_kernel()
-                if not ok:
-                    raise RuntimeError(
-                        "RMSNorm NVFP4 fused path disabled due to unsupported inputs "
-                        f"(falling back to unfused RMSNorm): {reason}")
-                else:
-                    from ..custom_ops.torch_custom_ops import \
-                        fused_add_rms_norm_quant
-
-                    orig_shape = tuple(hidden_states.shape)
-                    n = int(orig_shape[-1])
-                    hs_2d = hidden_states.reshape(-1, n).contiguous()
-                    res_2d = residual.reshape(-1, n)
-                    gamma = self.weight
-
-                    def _ensure_contiguous_with_dtype(t: torch.Tensor,
-                                                      key: str):
-                        if t.dtype != hs_2d.dtype:
-                            logger.warning_once(
-                                f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}.",
-                                key=f"rmsnorm_nvfp4_cast_{key}",
-                            )
-                            t = t.to(dtype=hs_2d.dtype)
-                        return t.contiguous()
-
-                    res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
-                    gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
-
-                    if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
-                        raise RuntimeError(
-                            "RMSNorm NVFP4 fused path requires all tensors on the same device. "
-                            f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
-                        )
-
-                    sf_scale = nvfp4_scale.contiguous(
-                    ) if nvfp4_scale is not None else None
-
-                    normed_fp4_i32, residual_out_2d, sf_fused = fused_add_rms_norm_quant(
-                        hs_2d,
-                        res_2d,
-                        gamma,
-                        sf_scale,
-                        True,
-                        eps=self.variance_epsilon,
+            def _ensure_contiguous_with_dtype(t: torch.Tensor, key: str):
+                if t.dtype != hs_2d.dtype:
+                    raise ValueError(
+                        f"RMSNorm NVFP4 fused path: {key} dtype {t.dtype} does not match input dtype {hs_2d.dtype}."
                     )
-                    normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
-                    if len(orig_shape) != 2:
-                        normed_fp4_u8 = normed_fp4_u8.reshape(
-                            *orig_shape[:-1], n // 2)
-                        residual_out = residual_out_2d.reshape(orig_shape)
-                    else:
-                        residual_out = residual_out_2d
+                return t.contiguous()
-                    hidden_states_fused = Fp4QuantizedTensor(
-                        normed_fp4_u8, sf_fused)
-                    return (hidden_states_fused, residual_out
-                            ) if return_residual else hidden_states_fused
+            res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
+            gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
+
+            if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
+                raise RuntimeError(
+                    "RMSNorm NVFP4 fused path requires all tensors on the same device. "
+                    f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
+ ) + + sf_scale = nvfp4_scale.contiguous() + + normed_fp4_i32, residual_out_2d, sf_fused = torch.ops.trtllm.fused_add_rms_norm_quant( + hs_2d, + res_2d, + gamma, + sf_scale, + True, + eps=self.variance_epsilon, + ) + normed_fp4_u8 = normed_fp4_i32.view(torch.uint8) + if len(orig_shape) != 2: + normed_fp4_u8 = normed_fp4_u8.reshape(*orig_shape[:-1], n // 2) + residual_out = residual_out_2d.reshape(orig_shape) + else: + residual_out = residual_out_2d + + hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused) + return (hidden_states_fused, + residual_out) if has_residual else hidden_states_fused if IS_FLASHINFER_AVAILABLE: from ..custom_ops import (flashinfer_fused_add_rmsnorm, @@ -218,7 +166,7 @@ class RMSNorm(nn.Module): hidden_states = (self.weight + 1) * hidden_states.to(input_dtype) - if return_residual: + if has_residual: return hidden_states, cast(Optional[torch.Tensor], residual) else: return hidden_states