Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
[None][fix] Update RMSNorm custom op plumbing (#10843)
Signed-off-by: jintaop <jintaop@nvidia.com>
parent 1dc49b266e
commit 9beb971827
@@ -916,3 +916,23 @@ def _register_fake():
         # This is a fake implementation for shape inference
         # The actual operation modifies fused_q and q_pe in-place
         return None
+
+    @torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
+    def _(
+        input: torch.Tensor,
+        residual: torch.Tensor,
+        gamma: torch.Tensor,
+        sf_scale: Optional[torch.Tensor],
+        use_rms_norm: bool = True,
+        eps: float = 1e-5,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        m, n = input.shape
+        # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
+        normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
+        # output: [M, N] pre-norm output, same dtype as input
+        output = input.new_empty((m, n), dtype=input.dtype)
+        # sf_out: scale factors, swizzled layout
+        sf_vec_size = 16
+        sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
+        sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
+        return normed_output_fp4, output, sf_out
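Aside (not part of this diff): a fake implementation registered via torch.library.register_fake only has to reproduce the op's output metadata (shapes, dtypes, devices), so PyTorch can trace through the op under FakeTensorMode or torch.compile without launching the CUDA kernel. The sketch below, with an illustrative helper name and sizes, reproduces the same shape arithmetic on plain Python integers:

def expected_fused_add_rms_norm_quant_shapes(m: int, n: int):
    # FP4 output packs 8 4-bit values into each int32 element.
    fp4_shape = (m, n // 8)
    # Pre-norm output keeps the input shape and dtype.
    prenorm_shape = (m, n)
    # Scale factors: one per 16 elements, padded to 128-row x 4-column tiles (swizzled layout).
    sf_vec_size = 16
    sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
    return fp4_shape, prenorm_shape, (sf_size, )

print(expected_fused_add_rms_norm_quant_shapes(32, 4096))
# ((32, 512), (32, 4096), (32768,))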
@@ -1876,59 +1876,6 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     tensor.record_stream(stream)


-def fused_add_rms_norm_quant(
-    input: torch.Tensor,
-    residual: torch.Tensor,
-    gamma: torch.Tensor,
-    sf_scale: Optional[torch.Tensor],
-    use_rms_norm: bool = True,
-    eps: float = 1e-6,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.
-
-    Args:
-        input: [M, N] input tensor (fp16/bf16)
-        residual: [M, N] residual tensor (fp16/bf16)
-        gamma: [N] normalization weight (fp16/bf16)
-        sf_scale: [1] optional scale factor for FP4 quantization (float32)
-        use_rms_norm: if True use RMSNorm, else use LayerNorm
-        eps: epsilon for normalization
-
-    Returns:
-        normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
-        output: [M, N] pre-norm output (input + residual), same dtype as input
-        sf_out: scale factors for FP4 quantization (uint8), swizzled layout
-
-    Note:
-        This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
-        Hidden dimension N must be >= 2048 and <= 16384.
-    """
-    return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma,
-                                                     sf_scale, use_rms_norm,
-                                                     eps)
-
-
-@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
-def _fused_add_rms_norm_quant_fake(
-    input: torch.Tensor,
-    residual: torch.Tensor,
-    gamma: torch.Tensor,
-    sf_scale: Optional[torch.Tensor],
-    use_rms_norm: bool = True,
-    eps: float = 1e-5,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    m, n = input.shape
-    # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
-    normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
-    # output: [M, N] pre-norm output, same dtype as input
-    output = input.new_empty((m, n), dtype=input.dtype)
-    # sf_out: scale factors, swizzled layout
-    sf_vec_size = 16
-    sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
-    sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
-    return normed_output_fp4, output, sf_out
-
-
 class Fp4GemmAllreduceRunner(TunableRunner):
     runner_dict = dict()
     tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
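Usage sketch (not part of this diff): calling the custom op directly through torch.ops, as the removed docstring above describes. It assumes a TensorRT-LLM build that registers trtllm::fused_add_rms_norm_quant on import and an SM90/SM100 GPU; the tensor sizes are illustrative.

import torch
import tensorrt_llm  # assumed to register the trtllm custom ops on import

m, n = 128, 4096  # N must satisfy 2048 <= N <= 16384
x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(n, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")

normed_fp4, prenorm, sf_out = torch.ops.trtllm.fused_add_rms_norm_quant(
    x, residual, gamma, sf_scale, True, eps=1e-6)
# normed_fp4: [128, 512] int32 (packed FP4)
# prenorm:    [128, 4096] bf16 (x + residual)
# sf_out:     uint8 scale factors in the swizzled layout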
@@ -20,8 +20,6 @@ from typing import Optional, Tuple, TypeAlias, Union, cast
 import torch
 from torch import nn

-from tensorrt_llm.logger import logger
-
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..utils import Fp4QuantizedTensor
@@ -76,109 +74,59 @@ class RMSNorm(nn.Module):
                 _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
     ) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[
             torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]:
-        return_residual = True
-        if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL:
-            return_residual = False
+        has_residual = residual is not self._ARGUMENT_NOT_SPECIFIED_SENTINEL
+        if not has_residual:
+            residual = None

-        if self.is_nvfp4 and residual is not None and not self.use_gemma:
+        if self.is_nvfp4 and has_residual and not self.use_gemma:
             nvfp4_scale = getattr(self, "nvfp4_scale", None)
             if nvfp4_scale is None:
                 raise ValueError(
                     f"layeridx={getattr(self, 'layer_idx', None)} RMSNorm NVFP4 output requested "
                     "but no `nvfp4_scale` is attached; ")
-            else:
-
-                def _can_use_fused_kernel() -> Tuple[bool, str]:
-                    if not hidden_states.is_cuda or not residual.is_cuda:
-                        return False, "inputs must be CUDA tensors"
-                    if not self.weight.is_cuda:
-                        return False, "gamma/weight must be a CUDA tensor"
-                    if hidden_states.ndim < 2:
-                        return False, "input must have rank >= 2"
-                    if hidden_states.shape != residual.shape:
-                        return False, f"input/residual shape mismatch: {tuple(hidden_states.shape)} vs {tuple(residual.shape)}"
-                    n = int(hidden_states.shape[-1])
-                    if self.weight.ndim != 1 or int(self.weight.numel()) != n:
-                        return False, f"gamma/weight must be 1D with numel == hidden_size ({n}), got shape={tuple(self.weight.shape)}"
-                    # Match the underlying C++ op: fp16/bf16 only (no fp8).
-                    if hidden_states.dtype not in (torch.float16,
-                                                   torch.bfloat16):
-                        return False, f"unsupported dtype {hidden_states.dtype} (expected fp16/bf16)"
-                    if n % 16 != 0:
-                        return False, f"hidden size must be divisible by 16 (got {n})"
-                    # Kernel constraints (see fusedAddRMSNormQuant.cpp).
-                    if n < 2048 or n > 16384:
-                        return False, f"hidden size must be in [2048, 16384] (got {n})"
-                    # SM90+ only.
-                    major, _minor = torch.cuda.get_device_capability(
-                        hidden_states.device)
-                    if major < 9:
-                        return False, f"requires SM90+ GPU, got SM{major}{_minor}"
-                    # Scale tensor constraints.
-                    if (nvfp4_scale is not None
-                            and ((not nvfp4_scale.is_cuda) or nvfp4_scale.dtype
-                                 != torch.float32 or nvfp4_scale.numel() != 1)):
-                        return False, f"nvfp4_scale must be a CUDA float32 tensor with numel==1 (got dtype={getattr(nvfp4_scale, 'dtype', None)}, device={getattr(nvfp4_scale, 'device', None)}, numel={getattr(nvfp4_scale, 'numel', lambda: None)()})"
-                    return True, ""
-
-                orig_shape = tuple(hidden_states.shape)
-                n = int(orig_shape[-1])
-                hs_2d = hidden_states.reshape(-1, n).contiguous()
-                res_2d = residual.reshape(-1, n)
-                gamma = self.weight
-
-                ok, reason = _can_use_fused_kernel()
-                if not ok:
-                    raise RuntimeError(
-                        "RMSNorm NVFP4 fused path disabled due to unsupported inputs "
-                        f"(falling back to unfused RMSNorm): {reason}")
-                else:
-                    from ..custom_ops.torch_custom_ops import \
-                        fused_add_rms_norm_quant
-
-                    def _ensure_contiguous_with_dtype(t: torch.Tensor,
-                                                      key: str):
-                        if t.dtype != hs_2d.dtype:
-                            logger.warning_once(
-                                f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}.",
-                                key=f"rmsnorm_nvfp4_cast_{key}",
-                            )
-                            t = t.to(dtype=hs_2d.dtype)
-                        return t.contiguous()
-
-                    res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
-                    gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
-
-                    if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
-                        raise RuntimeError(
-                            "RMSNorm NVFP4 fused path requires all tensors on the same device. "
-                            f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
-                        )
-
-                    sf_scale = nvfp4_scale.contiguous(
-                    ) if nvfp4_scale is not None else None
-
-                    normed_fp4_i32, residual_out_2d, sf_fused = fused_add_rms_norm_quant(
-                        hs_2d,
-                        res_2d,
-                        gamma,
-                        sf_scale,
-                        True,
-                        eps=self.variance_epsilon,
-                    )
-                    normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
-                    if len(orig_shape) != 2:
-                        normed_fp4_u8 = normed_fp4_u8.reshape(
-                            *orig_shape[:-1], n // 2)
-                        residual_out = residual_out_2d.reshape(orig_shape)
-                    else:
-                        residual_out = residual_out_2d
-
-                    hidden_states_fused = Fp4QuantizedTensor(
-                        normed_fp4_u8, sf_fused)
-                    return (hidden_states_fused, residual_out
-                            ) if return_residual else hidden_states_fused
+
+            orig_shape = tuple(hidden_states.shape)
+            n = int(orig_shape[-1])
+            hs_2d = hidden_states.reshape(-1, n).contiguous()
+            res_2d = residual.reshape(-1, n)
+            gamma = self.weight
+
+            def _ensure_contiguous_with_dtype(t: torch.Tensor, key: str):
+                if t.dtype != hs_2d.dtype:
+                    raise ValueError(
+                        f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}."
+                    )
+                return t.contiguous()
+
+            res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
+            gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
+
+            if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
+                raise RuntimeError(
+                    "RMSNorm NVFP4 fused path requires all tensors on the same device. "
+                    f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
+                )
+
+            sf_scale = nvfp4_scale.contiguous()
+
+            normed_fp4_i32, residual_out_2d, sf_fused = torch.ops.trtllm.fused_add_rms_norm_quant(
+                hs_2d,
+                res_2d,
+                gamma,
+                sf_scale,
+                True,
+                eps=self.variance_epsilon,
+            )
+            normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
+            if len(orig_shape) != 2:
+                normed_fp4_u8 = normed_fp4_u8.reshape(*orig_shape[:-1], n // 2)
+                residual_out = residual_out_2d.reshape(orig_shape)
+            else:
+                residual_out = residual_out_2d
+
+            hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused)
+            return (hidden_states_fused,
+                    residual_out) if has_residual else hidden_states_fused

         if IS_FLASHINFER_AVAILABLE:
             from ..custom_ops import (flashinfer_fused_add_rmsnorm,
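For reference (not part of this diff): setting aside the FP4 quantization and scale factors, the fused path above computes a residual add followed by an RMSNorm over the hidden dimension. A minimal unfused sketch, assuming standard RMSNorm semantics and illustrative names:

import torch

def reference_add_rms_norm(x, residual, gamma, eps=1e-6):
    # Pre-norm output: the residual add that the fused kernel also returns.
    prenorm = x + residual
    # RMSNorm over the last (hidden) dimension, accumulated in fp32 for stability.
    h = prenorm.float()
    inv_rms = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = (h * inv_rms).to(x.dtype) * gamma
    return normed, prenorm

x = torch.randn(4, 4096, dtype=torch.bfloat16)
res = torch.randn(4, 4096, dtype=torch.bfloat16)
gamma = torch.ones(4096, dtype=torch.bfloat16)
normed, prenorm = reference_add_rms_norm(x, res, gamma)

The fused kernel additionally quantizes the normalized output to packed FP4 and emits per-block scale factors, which is what the Fp4QuantizedTensor returned above wraps.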
@@ -218,7 +166,7 @@ class RMSNorm(nn.Module):
             hidden_states = (self.weight +
                              1) * hidden_states.to(input_dtype)

-        if return_residual:
+        if has_residual:
             return hidden_states, cast(Optional[torch.Tensor], residual)
         else:
             return hidden_states