/*
 * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <torch/library.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <optional>
#include <tuple>
#include <unordered_map>

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{

// Fused Add + RMSNorm + FP4 Quantization kernel
// input:    [M, N] - input tensor (fp16/bf16)
// residual: [M, N] - residual tensor (fp16/bf16)
// gamma:    [N]    - RMSNorm weight (fp16/bf16)
// sf_scale: [1]    - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// Returns:
//   normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
//   output:        [M, N]   - pre-norm output (input + residual), same dtype as input
//   sf_out:        scale factors for FP4 (uint8_t), swizzled layout
//
// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
    at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale,
    bool use_rms_norm, double eps)
{
    CHECK_TH_CUDA(input);
    CHECK_CONTIGUOUS(input);
    CHECK_TH_CUDA(residual);
    CHECK_CONTIGUOUS(residual);
    CHECK_TH_CUDA(gamma);
    CHECK_CONTIGUOUS(gamma);

    // Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell)
    auto const device = input.get_device();
    cudaDeviceProp props;
    AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device));
    TORCH_CHECK(props.major >= 9,
        "fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. "
        "Current device: sm_",
        props.major, props.minor);

    auto const& inputShape = input.sizes();
    auto const rank = inputShape.size();
    TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N].");
    TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape.");

    int64_t const m = inputShape[0];
    int64_t const n = inputShape[1];
    // Some warp-specialized kernels may issue vectorized stores that assume M is padded.
    // Allocate a bit of extra space to avoid out-of-bounds writes when M is not a multiple of 8.
    int64_t const m_padded = (m + 31) / 32 * 32;

    TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N.");
    TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint).");
    TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384.");
    TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization.");

    // Validate sf_scale if provided
    float* sfScalePtr = nullptr;
    if (sf_scale.has_value())
    {
        CHECK_INPUT(sf_scale.value(), torch::kFloat32);
        sfScalePtr = sf_scale.value().data_ptr<float>();
    }
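    // Illustrative layout arithmetic (an example, not an additional check): with N = 4096, each
    // output row packs 4096 / 8 = 512 uint32 words (8 FP4 values per word) and produces
    // 4096 / 16 = 256 scale factors (one per group of sfVecSize = 16 values), stored in the
    // swizzled layout computed below.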
    // Allocate output tensors
    // normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32)
    // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
    at::Tensor normed_output_padded
        = at::detail::empty_cuda({m_padded, n / 8}, torch::kInt32, input.device(), std::nullopt);
    at::Tensor normed_output = (m_padded == m) ? normed_output_padded : normed_output_padded.narrow(0, 0, m);

    // output: pre-norm output (input + residual) [M, N], same dtype as input
    // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
    at::Tensor output_padded
        = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
    at::Tensor output = (m_padded == m) ? output_padded : output_padded.narrow(0, 0, m);

    // sf_out: scale factors for FP4, swizzled layout
    // sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor)
    int64_t const sfVecSize = 16;
    // NOTE: allocate using m_padded to avoid OOB writes for warp-specialized/vectorized stores when M is not padded.
    // Return a view of the original (un-padded) size to keep the API stable.
    int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize);
    int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize);
    at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt);
    at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize);
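    // Worked example (illustrative): for M = 100 the padded buffers hold m_padded = 128 rows, but the
    // narrow(0, 0, m) views returned to the caller expose only the first 100 rows (and, for sf_out,
    // only the first computeSwizzledLayoutSFSize(100, N / 16) scale-factor entries).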
    // Get number of SMs for persistent kernel
    static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

    // Allocate counters for warp-specialized kernel using PyTorch allocator.
    //
    // NOTE: We cache this tensor to avoid per-call allocations. We use `thread_local` so
    // concurrent calls from different threads don't share the same counters buffer (which
    // could cause races across different CUDA streams).
    static thread_local std::unordered_map<int, at::Tensor> counters_tensor_cache;
    auto& counters_tensor = counters_tensor_cache[device];
    int64_t const counters_bytes = static_cast<int64_t>(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters));
    if (!counters_tensor.defined() || counters_tensor.numel() != counters_bytes)
    {
        counters_tensor = at::detail::empty_cuda({counters_bytes}, torch::kByte, input.device(), std::nullopt);
        counters_tensor.zero_();
    }
    auto* counters
        = reinterpret_cast<tensorrt_llm::kernels::WarpSpecializedCounters*>(counters_tensor.mutable_data_ptr());

    auto stream = at::cuda::getCurrentCUDAStream(device);

#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T)                                                                             \
    do                                                                                                                 \
    {                                                                                                                  \
        using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam<T>;                            \
        tensorrt_llm::kernels::WarpSpecializedParam<Param> param;                                                      \
        param.normed_output = reinterpret_cast<uint32_t*>(normed_output.data_ptr());                                   \
        param.output = reinterpret_cast<T*>(output.data_ptr());                                                        \
        param.input = const_cast<T*>(reinterpret_cast<T const*>(input.data_ptr()));                                    \
        param.sf_scale = sfScalePtr;                                                                                   \
        param.sf_out = reinterpret_cast<uint8_t*>(sf_out.data_ptr());                                                  \
        param.residual = reinterpret_cast<T*>(residual.data_ptr());                                                    \
        param.bias = nullptr;                                                                                          \
        param.gamma = reinterpret_cast<T*>(gamma.data_ptr());                                                          \
        param.beta = nullptr;                                                                                          \
        param.m = static_cast<int>(m);                                                                                 \
        param.n = static_cast<int>(n);                                                                                 \
        param.layernorm_eps = static_cast<float>(eps);                                                                 \
        param.stream = stream;                                                                                         \
        param.counters = counters;                                                                                     \
        tensorrt_llm::kernels::invokeWSLayerNorm(param, use_rms_norm, multiProcessorCount);                            \
    } while (0)

    if (input.scalar_type() == at::ScalarType::Half)
    {
        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half);
    }
    else if (input.scalar_type() == at::ScalarType::BFloat16)
    {
#ifdef ENABLE_BF16
        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16);
#else
        C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input.");
#endif
    }
    else
    {
        C10_THROW_ERROR(
            NotImplementedError, "fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16.");
    }
#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT

    // No explicit sync needed - kernel runs asynchronously on the stream
    return std::make_tuple(normed_output, output, sf_out);
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
        "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("fused_add_rms_norm_quant", &tensorrt_llm::torch_ext::fused_add_rms_norm_quant);
}
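// Illustrative usage (a minimal sketch, not part of this translation unit): the op can be invoked
// directly through the C++ entry point registered above; tensor shapes and dtypes must follow the
// constraints documented at the top of this file.
//
//   at::Tensor x     = at::randn({128, 4096}, at::dtype(at::kHalf).device(at::kCUDA));
//   at::Tensor res   = at::randn_like(x);
//   at::Tensor gamma = at::ones({4096}, at::dtype(at::kHalf).device(at::kCUDA));
//   auto [normed_fp4, prenorm, sf] = tensorrt_llm::torch_ext::fused_add_rms_norm_quant(
//       x, res, gamma, /*sf_scale=*/std::nullopt, /*use_rms_norm=*/true, /*eps=*/1e-6);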