From a08e8f7bbcb11da2663a509a16d03462ddc769d2 Mon Sep 17 00:00:00 2001 From: JtaoPeng Date: Tue, 13 Jan 2026 06:28:18 +0000 Subject: [PATCH] update torch_ext API and debugging test for FusedAddRMSNorm update #define for hopper & blackwell Update cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp Signed-off-by: jintaop --- .../low_latency_layernorm.cuh | 33 +-- .../fusedLayernormKernels/ws_layernorm.cuh | 71 ++++--- cpp/tensorrt_llm/thop/CMakeLists.txt | 1 + .../thop/fusedAddRMSNormQuant.cpp | 200 ++++++++++++++++++ .../_torch/custom_ops/torch_custom_ops.py | 53 +++++ tensorrt_llm/_torch/models/modeling_llama.py | 45 +++- tensorrt_llm/_torch/modules/rms_norm.py | 111 +++++++++- tensorrt_llm/mapping.py | 33 +++ 8 files changed, 497 insertions(+), 50 deletions(-) create mode 100644 cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh index 9545d919c5..6a925c5510 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh @@ -115,8 +115,6 @@ struct LowLatencyLayerNorm uint32_t work_id = blockIdx.x; - FusedOperator fused_operator(param); - constexpr auto PACKED_PER_N_BLOCK = Traits::N_BLOCK / N_THREADS / Traits::PACKED_ELEMS_PER_COMPUTE; typename Traits::AccumulatorType data[PACKED_PER_N_BLOCK][Traits::PACKED_ELEMS_PER_COMPUTE]; @@ -139,7 +137,7 @@ struct LowLatencyLayerNorm for (int i = 0; i < PACKED_PER_N_BLOCK; i++) { auto offset = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE; - if (offset <= sz) + if (offset < sz) { data[i] = *reinterpret_cast(&g_data[offset]); } @@ -155,6 +153,14 @@ struct LowLatencyLayerNorm static_assert(Traits::OUTPUT_SCALE != SCALE_TYPE::VECTOR); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) + if constexpr (arch::is_major_v<9> || arch::is_major_v<10>) + { + cudaGridDependencySynchronize(); + } +#endif + FusedOperator fused_operator(param); + if constexpr (Traits::BIAS == SCALE_TYPE::VECTOR) { load_to_register(param.bias, r_bias, param.n); @@ -175,13 +181,6 @@ struct LowLatencyLayerNorm load_to_register(param.beta, r_beta, param.n); } -#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12)) - if constexpr (arch::is_major_v<9> || arch::is_major_v<10>) - { - cudaGridDependencySynchronize(); - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif load_to_register(¶m.input[work_id * param.n], data, param.n); if constexpr (Traits::RESIDUAL) @@ -259,12 +258,12 @@ struct LowLatencyLayerNorm if constexpr (!Traits::RMS_NORM) { mean = var_and_mean[1] / param.n; - variance = rsqrtf( - var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(1e-5)); + variance = rsqrtf(var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + + (Traits::AccumulatorType)(param.layernorm_eps)); } else { - variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(1e-5)); + variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(param.layernorm_eps)); } for (int i = 0; i < PACKED_PER_N_BLOCK; i++) @@ -333,6 +332,14 @@ struct LowLatencyLayerNorm { __shared__ Shared shared; compute(param, &shared); + __syncthreads(); + asm volatile("membar.gl;" : : : "memory"); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) + if constexpr (arch::is_major_v<9> || arch::is_major_v<10>) + { + 
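+        // PDL note: launch completion is triggered only after compute() has returned and the
+        // __syncthreads()/membar.gl above have made this block's global writes visible, so a
+        // dependent kernel overlapped via programmatic dependent launch observes consistent data.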
cudaTriggerProgrammaticLaunchCompletion(); + } +#endif } }; diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh index e850086f1b..5359b9dc55 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh @@ -201,25 +201,35 @@ struct WarpSpecializedLayerNorm } // if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base); + if constexpr (FIRST_RUN) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) + if constexpr (arch::is_major_v<9> || arch::is_major_v<10>) + { + // Ensure upstream kernel writes are visible before reading dependent activation/residual data. + cudaGridDependencySynchronize(); + } +#endif + } + const uint32_t eff_m_block + = std::min(static_cast(Traits::M_BLOCK), static_cast(param.m - m_base)); const auto tx - = (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1)) - + (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0); + = (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1)) + + (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0); auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx); // if (blockIdx.x == 0) printf("SMEM buffer ready, start loading tile %d.\n", m_base); - if constexpr (FIRST_RUN) - { - cudaGridDependencySynchronize(); - } - for (int i = 0; i < Traits::M_BLOCK; i++) { - load_a_vec(¶m.input[(m_base + i) * param.n], - __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]), - param.n * sizeof(typename Traits::InputType), - __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr))); + if (i < eff_m_block) [[likely]] + { + load_a_vec(¶m.input[(m_base + i) * param.n], + __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]), + param.n * sizeof(typename Traits::InputType), + __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr))); + } } // Use templated lambdas to defer resolving the symbols like "param.residual". @@ -231,10 +241,13 @@ struct WarpSpecializedLayerNorm { for (int i = 0; i < Traits::M_BLOCK; i++) { - load_a_vec(¶m.residual[(m_base + i) * param.n], - __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]), - param.n * sizeof(typename Traits::InputType), - __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr))); + if (i < eff_m_block) [[likely]] + { + load_a_vec(¶m.residual[(m_base + i) * param.n], + __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]), + param.n * sizeof(typename Traits::InputType), + __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr))); + } } }(param); } @@ -423,6 +436,13 @@ struct WarpSpecializedLayerNorm using FusedOperator = GetFusedOperator; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) + if constexpr (arch::is_major_v<9> || arch::is_major_v<10>) + { + // Ensure upstream kernel writes are visible before reading dependent activation/residual data. 
+ cudaGridDependencySynchronize(); + } +#endif FusedOperator fused_operator(param); static_assert(Traits::PERSISTENT_MODE || Traits::MATH_WARPGROUPS == 1); @@ -446,6 +466,9 @@ struct WarpSpecializedLayerNorm { m_base = block_id; } + const uint32_t eff_m_block + = std::min(static_cast(Traits::M_BLOCK), static_cast(param.m - m_base)); + // if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base); // Peek for data ready. @@ -613,11 +636,12 @@ struct WarpSpecializedLayerNorm { mean[m_offset] /= param.n; variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset] - + (Traits::AccumulatorType)(1e-5)); + + (Traits::AccumulatorType)(param.layernorm_eps)); } else { - variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5)); + variance[m_offset] + = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps)); } } @@ -659,8 +683,7 @@ struct WarpSpecializedLayerNorm } } -#pragma unroll Traits::M_BLOCK - for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++) + for (int m_offset = 0; m_offset < eff_m_block; m_offset++) { auto m = m_base + m_offset; @@ -801,23 +824,19 @@ struct WarpSpecializedLayerNorm shared->init(threadIdx.x == 0); __syncthreads(); -#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12)) -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL)) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12) if constexpr (arch::is_major_v<9> || arch::is_major_v<10>) { auto block_id = blockIdx.x; auto warp_id = threadIdx.x / 32; auto lane_id = threadIdx.x % 32; auto tid_in_wg = threadIdx.x % 128; - if (warp_id < 4) { asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}"); if (warp_id == 0) { scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared); - // PRE-EXIT after all tiles have been scheduled. - cudaTriggerProgrammaticLaunchCompletion(); } else if (warp_id == 1) { @@ -829,8 +848,10 @@ struct WarpSpecializedLayerNorm asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}"); compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared); } + __syncthreads(); + asm volatile("membar.gl;" : : : "memory"); + cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif } }; diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index 20fcc35b82..b06a0782a3 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -66,6 +66,7 @@ add_library( fp8Quantize.cpp dsv3FusedAGemmOp.cpp fusedQKNormRopeOp.cpp + fusedAddRMSNormQuant.cpp fusedTopkSoftmax.cpp gatherTreeOp.cpp groupRmsNormOp.cpp diff --git a/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp b/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp new file mode 100644 index 0000000000..0d76aff4f8 --- /dev/null +++ b/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h" +#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h" +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/thop/thUtils.h" + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace torch_ext +{ + +// Fused Add + RMSNorm + FP4 Quantization kernel +// input: [M, N] - input tensor (fp16/bf16) +// residual: [M, N] - residual tensor (fp16/bf16) +// gamma: [N] - RMSNorm weight (fp16/bf16) +// sf_scale: [1] - optional scale factor for FP4 quantization (float) +// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm +// Returns: +// normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed) +// output: [M, N] - pre-norm output (input + residual), same dtype as input +// sf_out: scale factors for FP4 (uint8_t), swizzled layout +// +// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture. +// NOTE: Hidden dimension N must be >= 2048 and <= 16384. +std::tuple fused_add_rms_norm_quant(at::Tensor const& input, + at::Tensor const& residual, at::Tensor const& gamma, std::optional const& sf_scale, bool use_rms_norm, + double eps) +{ + CHECK_TH_CUDA(input); + CHECK_CONTIGUOUS(input); + CHECK_TH_CUDA(residual); + CHECK_CONTIGUOUS(residual); + CHECK_TH_CUDA(gamma); + CHECK_CONTIGUOUS(gamma); + + // Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell) + auto const device = input.get_device(); + cudaDeviceProp props; + AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device)); + TORCH_CHECK(props.major >= 9, + "fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. " + "Current device: sm_", + props.major, props.minor); + + auto const& inputShape = input.sizes(); + auto const& rank = inputShape.size(); + + TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N]."); + TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape."); + + int64_t const m = inputShape[0]; + int64_t const n = inputShape[1]; + // Some warp-specialized kernels may issue vectorized stores that assume M is padded. + // Allocate a bit of extra space to avoid out-of-bounds writes when M is not a multiple of 8. + int64_t const m_padded = (m + 31) / 32 * 32; + + TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N."); + TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint)."); + TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384."); + TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization."); + + // Validate sf_scale if provided + float* sfScalePtr = nullptr; + if (sf_scale.has_value()) + { + CHECK_INPUT(sf_scale.value(), torch::kFloat32); + sfScalePtr = sf_scale.value().data_ptr(); + } + + // Allocate output tensors + // normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32) + // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable. + at::Tensor normed_output_padded + = at::detail::empty_cuda({m_padded, n / 8}, torch::kInt32, input.device(), std::nullopt); + at::Tensor normed_output = (m_padded == m) ? 
normed_output_padded : normed_output_padded.narrow(0, 0, m); + + // output: pre-norm output (input + residual) [M, N], same dtype as input + // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable. + at::Tensor output_padded = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt); + at::Tensor output = (m_padded == m) ? output_padded : output_padded.narrow(0, 0, m); + + // sf_out: scale factors for FP4, swizzled layout + // sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor) + int64_t const sfVecSize = 16; + // NOTE: allocate using m_padded to avoid OOB writes for warp-specialized/vectorized stores when M is not padded. + // Return a view of the original (un-padded) size to keep the API stable. + int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize); + int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize); + at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt); + at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize); + + // Get number of SMs for persistent kernel + static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); + + // Allocate counters for warp-specialized kernel using PyTorch allocator. + // + // NOTE: We cache this tensor to avoid per-call allocations. We use `thread_local` so + // concurrent calls from different threads don't share the same counters buffer (which + // could cause races across different CUDA streams). + static thread_local std::unordered_map counters_tensor_cache; + auto& counters_tensor = counters_tensor_cache[device]; + int64_t const counters_bytes = static_cast(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters)); + if (!counters_tensor.defined() || counters_tensor.numel() != counters_bytes) + { + counters_tensor = at::detail::empty_cuda({counters_bytes}, torch::kByte, input.device(), std::nullopt); + counters_tensor.zero_(); + } + auto* counters + = reinterpret_cast(counters_tensor.mutable_data_ptr()); + + auto stream = at::cuda::getCurrentCUDAStream(device); + +#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T) \ + do \ + { \ + using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam; \ + tensorrt_llm::kernels::WarpSpecializedParam param; \ + param.normed_output = reinterpret_cast(normed_output.data_ptr()); \ + param.output = reinterpret_cast(output.data_ptr()); \ + param.input = const_cast(reinterpret_cast(input.data_ptr())); \ + param.sf_scale = sfScalePtr; \ + param.sf_out = reinterpret_cast(sf_out.data_ptr()); \ + param.residual = reinterpret_cast(residual.data_ptr()); \ + param.bias = nullptr; \ + param.gamma = reinterpret_cast(gamma.data_ptr()); \ + param.beta = nullptr; \ + param.m = static_cast(m); \ + param.n = static_cast(n); \ + param.layernorm_eps = static_cast(eps); \ + param.stream = stream; \ + param.counters = counters; \ + tensorrt_llm::kernels::invokeWSLayerNorm(param, use_rms_norm, multiProcessorCount); \ + } while (0) + + if (input.scalar_type() == at::ScalarType::Half) + { + LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half); + } + else if (input.scalar_type() == at::ScalarType::BFloat16) + { +#ifdef ENABLE_BF16 + LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16); +#else + C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input."); +#endif + } + else + { + C10_THROW_ERROR( + NotImplementedError, 
"fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16."); + } + +#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT + // No explicit sync needed - kernel runs asynchronously on the stream + return std::make_tuple(normed_output, output, sf_out); +} + +} // namespace torch_ext + +TRTLLM_NAMESPACE_END + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, " + "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("fused_add_rms_norm_quant", &tensorrt_llm::torch_ext::fused_add_rms_norm_quant); +} diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 2ee8d29ccc..cd3bd0f556 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1869,3 +1869,56 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None: stream = get_stream(stream_id) assert stream is not None tensor.record_stream(stream) + + +def fused_add_rms_norm_quant( + input: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + sf_scale: Optional[torch.Tensor], + use_rms_norm: bool = True, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel. + + Args: + input: [M, N] input tensor (fp16/bf16) + residual: [M, N] residual tensor (fp16/bf16) + gamma: [N] normalization weight (fp16/bf16) + sf_scale: [1] optional scale factor for FP4 quantization (float32) + use_rms_norm: if True use RMSNorm, else use LayerNorm + eps: epsilon for normalization + + Returns: + normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed) + output: [M, N] pre-norm output (input + residual), same dtype as input + sf_out: scale factors for FP4 quantization (uint8), swizzled layout + + Note: + This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU. + Hidden dimension N must be >= 2048 and <= 16384. 
+ """ + return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma, + sf_scale, use_rms_norm, + eps) + + +@torch.library.register_fake("trtllm::fused_add_rms_norm_quant") +def _fused_add_rms_norm_quant_fake( + input: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + sf_scale: Optional[torch.Tensor], + use_rms_norm: bool = True, + eps: float = 1e-5, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + m, n = input.shape + # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32) + normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32) + # output: [M, N] pre-norm output, same dtype as input + output = input.new_empty((m, n), dtype=input.dtype) + # sf_out: scale factors, swizzled layout + sf_vec_size = 16 + sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4 + sf_out = input.new_empty((sf_size, ), dtype=torch.uint8) + return normed_output_fp4, output, sf_out diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 464a446cb3..f5678536c9 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -625,6 +625,7 @@ class LlamaDecoderLayer(DecoderLayer): super().__init__() config = model_config.pretrained_config self.layer_idx = layer_idx + self.num_hidden_layers = config.num_hidden_layers self.mapping = model_config.mapping self.enable_attention_dp = model_config.mapping.enable_attention_dp self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant( @@ -649,14 +650,30 @@ class LlamaDecoderLayer(DecoderLayer): layer_idx=layer_idx, use_custom_cublas_mm=use_custom_cublas_mm, ) - self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - - self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) + differ_pp_stage_with_previous_layer = False + if self.mapping.has_pp(): + prev_layer_idx = max(self.layer_idx - 1, 0) + differ_pp_stage_with_previous_layer = ( + self.layer_idx > 0 and self.mapping.pp_rank_of_layer( + self.layer_idx, + self.num_hidden_layers) != self.mapping.pp_rank_of_layer( + prev_layer_idx, self.num_hidden_layers)) + self.disable_nvfp4_layernorm_fusion = os.environ.get( + "TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "1") == "1" + self.input_layernorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + quantize_type="nvfp4" + if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4 + and not (differ_pp_stage_with_previous_layer) else None) + self.post_attention_layernorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion + and self.is_nvfp4 else None) self.all_reduce = AllReduce(mapping=model_config.mapping, strategy=model_config.allreduce_strategy) @@ -676,7 +693,6 @@ class LlamaDecoderLayer(DecoderLayer): self.PRE_MLP_FUSION = self.mapping.has_tp( ) and not self.enable_attention_dp and self.enable_fusion self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion - if self.is_nvfp4: self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 @@ -697,17 +713,16 @@ class LlamaDecoderLayer(DecoderLayer): def forward( self, position_ids: torch.IntTensor, - hidden_states: torch.Tensor, + 
hidden_states: Union[torch.Tensor, Fp4QuantizedTensor], attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], spec_metadata: Optional[SpecMetadata] = None, **kwargs, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Fp4QuantizedTensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states = self.self_attn( position_ids=position_ids, hidden_states=hidden_states, @@ -739,6 +754,8 @@ class LlamaDecoderLayer(DecoderLayer): else: hidden_states, residual = all_reduce_output else: + if self.is_nvfp4: + self.post_attention_layernorm.nvfp4_scale = self.mlp.gate_up_proj.input_scale hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) @@ -803,6 +820,12 @@ class LlamaDecoderLayer(DecoderLayer): else: hidden_states, residual = all_reduce_output elif self.next_layer_layernorm: + # NOTE: for the last decoder layer, `next_layer_layernorm` is the final model norm without nvfp4 quant + # (`self.model.norm`), and `next_attn` is expected to be None. + if self.next_attn is not None and hasattr(self.next_attn.qkv_proj, + 'input_scale'): + self.next_layer_layernorm.nvfp4_scale = self.next_attn.qkv_proj.input_scale + hidden_states, residual = self.next_layer_layernorm( hidden_states, residual) diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index 0d2228eff5..a61f91e0cf 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -20,7 +20,10 @@ from typing import Optional, Tuple, TypeAlias, Union, cast import torch from torch import nn +from tensorrt_llm.logger import logger + from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE +from ..utils import Fp4QuantizedTensor class RMSNorm(nn.Module): @@ -37,11 +40,17 @@ class RMSNorm(nn.Module): device: Optional[torch.device] = None, has_weights: bool = True, use_gemma: bool = False, + quantize_type: Optional[str] = None, ): super().__init__() if use_gemma and not has_weights: raise ValueError("has_weights must be True if use_gemma is True") + if quantize_type is not None: + if quantize_type != "nvfp4": + raise NotImplementedError( + f"Quantize type {quantize_type} not implemented in RMSNorm") + self.is_nvfp4 = quantize_type == "nvfp4" if has_weights: if not use_gemma: @@ -65,12 +74,112 @@ class RMSNorm(nn.Module): residual: Union[ Optional[torch.Tensor], _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + ) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[ + torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]: return_residual = True if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL: return_residual = False residual = None + if self.is_nvfp4 and residual is not None and not self.use_gemma: + nvfp4_scale = getattr(self, "nvfp4_scale", None) + if nvfp4_scale is None: + raise ValueError( + f"layeridx={getattr(self, 'layer_idx', None)} RMSNorm NVFP4 output requested " + "but no `nvfp4_scale` is attached; ") + else: + + def _can_use_fused_kernel() -> Tuple[bool, str]: + if not hidden_states.is_cuda or not residual.is_cuda: + return False, "inputs must be CUDA tensors" + if not self.weight.is_cuda: + return False, "gamma/weight must be a CUDA tensor" + if hidden_states.ndim < 2: + return False, "input must have rank >= 2" + if hidden_states.shape != residual.shape: + return False, f"input/residual shape mismatch: {tuple(hidden_states.shape)} vs 
{tuple(residual.shape)}" + n = int(hidden_states.shape[-1]) + if self.weight.ndim != 1 or int(self.weight.numel()) != n: + return False, f"gamma/weight must be 1D with numel == hidden_size ({n}), got shape={tuple(self.weight.shape)}" + # Match the underlying C++ op: fp16/bf16 only (no fp8). + if hidden_states.dtype not in (torch.float16, + torch.bfloat16): + return False, f"unsupported dtype {hidden_states.dtype} (expected fp16/bf16)" + if n % 16 != 0: + return False, f"hidden size must be divisible by 16 (got {n})" + # Kernel constraints (see fusedAddRMSNormQuant.cpp). + if n < 2048 or n > 16384: + return False, f"hidden size must be in [2048, 16384] (got {n})" + # SM90+ only. + major, _minor = torch.cuda.get_device_capability( + hidden_states.device) + if major < 9: + return False, f"requires SM90+ GPU, got SM{major}{_minor}" + # Scale tensor constraints. + if (nvfp4_scale is not None + and ((not nvfp4_scale.is_cuda) or nvfp4_scale.dtype + != torch.float32 or nvfp4_scale.numel() != 1)): + return False, f"nvfp4_scale must be a CUDA float32 tensor with numel==1 (got dtype={getattr(nvfp4_scale, 'dtype', None)}, device={getattr(nvfp4_scale, 'device', None)}, numel={getattr(nvfp4_scale, 'numel', lambda: None)()})" + return True, "" + + ok, reason = _can_use_fused_kernel() + if not ok: + raise RuntimeError( + "RMSNorm NVFP4 fused path disabled due to unsupported inputs " + f"(falling back to unfused RMSNorm): {reason}") + else: + from ..custom_ops.torch_custom_ops import \ + fused_add_rms_norm_quant + + orig_shape = tuple(hidden_states.shape) + n = int(orig_shape[-1]) + hs_2d = hidden_states.reshape(-1, n).contiguous() + res_2d = residual.reshape(-1, n) + gamma = self.weight + + def _ensure_contiguous_with_dtype(t: torch.Tensor, + key: str): + if t.dtype != hs_2d.dtype: + logger.warning_once( + f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}.", + key=f"rmsnorm_nvfp4_cast_{key}", + ) + t = t.to(dtype=hs_2d.dtype) + return t.contiguous() + + res_2d = _ensure_contiguous_with_dtype(res_2d, "residual") + gamma = _ensure_contiguous_with_dtype(gamma, "gamma") + + if hs_2d.device != res_2d.device or hs_2d.device != gamma.device: + raise RuntimeError( + "RMSNorm NVFP4 fused path requires all tensors on the same device. " + f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}." 
+ ) + + sf_scale = nvfp4_scale.contiguous( + ) if nvfp4_scale is not None else None + + normed_fp4_i32, residual_out_2d, sf_fused = fused_add_rms_norm_quant( + hs_2d, + res_2d, + gamma, + sf_scale, + True, + eps=self.variance_epsilon, + ) + normed_fp4_u8 = normed_fp4_i32.view(torch.uint8) + if len(orig_shape) != 2: + normed_fp4_u8 = normed_fp4_u8.reshape( + *orig_shape[:-1], n // 2) + residual_out = residual_out_2d.reshape(orig_shape) + else: + residual_out = residual_out_2d + + hidden_states_fused = Fp4QuantizedTensor( + normed_fp4_u8, sf_fused) + return (hidden_states_fused, residual_out + ) if return_residual else hidden_states_fused + if IS_FLASHINFER_AVAILABLE: from ..custom_ops import (flashinfer_fused_add_rmsnorm, flashinfer_gemma_fused_add_rmsnorm, diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index 818ee33dce..a650e76122 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -328,6 +328,39 @@ class MappingBase: return torch.tensor_split(torch.arange(num_layers), self.pp_size)[self.pp_rank].tolist() + def pp_rank_of_layer(self, layer_idx: int, num_layers: int) -> int: + """Return pipeline-parallel rank that owns `layer_idx` for a model with `num_layers` layers. + Mirrors the partitioning behavior in `pp_layers()`. + """ + if layer_idx < 0 or layer_idx >= num_layers: + raise ValueError(f"{layer_idx=} is out of range for {num_layers=}.") + if not self.has_pp(): + return 0 + + if self.pp_partition is not None: + if len(self.pp_partition) != self.pp_size: + raise ValueError( + f"{len(self.pp_partition)=} does not match {self.pp_size=}." + ) + if sum(self.pp_partition) != num_layers: + raise ValueError( + f"{sum(self.pp_partition)=} does not match {num_layers=}.") + end = 0 + for pp_rank, n in enumerate(self.pp_partition): + end += n + if layer_idx < end: + return pp_rank + raise RuntimeError("Unreachable: invalid pp_partition.") + + base, rem = divmod(num_layers, self.pp_size) + if base == 0: + # Matches torch.tensor_split: first `num_layers` ranks get one layer. + return layer_idx + cutoff = (base + 1) * rem + if layer_idx < cutoff: + return layer_idx // (base + 1) + return rem + (layer_idx - cutoff) // base + def ep_experts(self, num_experts: int) -> List[int]: assert self.cp_size == 1 experts_per_rank = num_experts // self.moe_ep_size