彭晋韬(jtao peng) 2026-01-13 21:07:28 +08:00 committed by GitHub
commit ae855d7f67
8 changed files with 497 additions and 50 deletions

View File

@ -115,8 +115,6 @@ struct LowLatencyLayerNorm
uint32_t work_id = blockIdx.x;
FusedOperator fused_operator(param);
constexpr auto PACKED_PER_N_BLOCK = Traits::N_BLOCK / N_THREADS / Traits::PACKED_ELEMS_PER_COMPUTE;
typename Traits::AccumulatorType data[PACKED_PER_N_BLOCK][Traits::PACKED_ELEMS_PER_COMPUTE];
@ -139,7 +137,7 @@ struct LowLatencyLayerNorm
for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
{
auto offset = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE;
if (offset <= sz)
if (offset < sz)
{
data[i] = *reinterpret_cast<PackedType const*>(&g_data[offset]);
}
@ -155,6 +153,14 @@ struct LowLatencyLayerNorm
static_assert(Traits::OUTPUT_SCALE != SCALE_TYPE::VECTOR);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaGridDependencySynchronize();
}
#endif
FusedOperator fused_operator(param);
if constexpr (Traits::BIAS == SCALE_TYPE::VECTOR)
{
load_to_register(param.bias, r_bias, param.n);
@ -175,13 +181,6 @@ struct LowLatencyLayerNorm
load_to_register(param.beta, r_beta, param.n);
}
#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
load_to_register(&param.input[work_id * param.n], data, param.n);
if constexpr (Traits::RESIDUAL)
@ -259,12 +258,12 @@ struct LowLatencyLayerNorm
if constexpr (!Traits::RMS_NORM)
{
mean = var_and_mean[1] / param.n;
variance = rsqrtf(
var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(1e-5));
variance = rsqrtf(var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1]
+ (Traits::AccumulatorType)(param.layernorm_eps));
}
else
{
variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(1e-5));
variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
}
for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
@ -333,6 +332,14 @@ struct LowLatencyLayerNorm
{
__shared__ Shared shared;
compute(param, &shared);
__syncthreads();
asm volatile("membar.gl;" : : : "memory");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
}
};

View File

@ -201,25 +201,35 @@ struct WarpSpecializedLayerNorm
}
// if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base);
if constexpr (FIRST_RUN)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
// Ensure upstream kernel writes are visible before reading dependent activation/residual data.
cudaGridDependencySynchronize();
}
#endif
}
const uint32_t eff_m_block
= std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
const auto tx
= (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+ (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0);
= (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+ (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0);
auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx);
// if (blockIdx.x == 0) printf("SMEM buffer ready, start loading tile %d.\n", m_base);
if constexpr (FIRST_RUN)
{
cudaGridDependencySynchronize();
}
for (int i = 0; i < Traits::M_BLOCK; i++)
{
load_a_vec(&param.input[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
if (i < eff_m_block) [[likely]]
{
load_a_vec(&param.input[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
}
}
// Use templated lambdas to defer resolving the symbols like "param.residual".
@ -231,10 +241,13 @@ struct WarpSpecializedLayerNorm
{
for (int i = 0; i < Traits::M_BLOCK; i++)
{
load_a_vec(&param.residual[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
if (i < eff_m_block) [[likely]]
{
load_a_vec(&param.residual[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
}
}
}(param);
}
@ -423,6 +436,13 @@ struct WarpSpecializedLayerNorm
using FusedOperator = GetFusedOperator<typename Traits::FusedOperator>;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
// Ensure upstream kernel writes are visible before reading dependent activation/residual data.
cudaGridDependencySynchronize();
}
#endif
FusedOperator fused_operator(param);
static_assert(Traits::PERSISTENT_MODE || Traits::MATH_WARPGROUPS == 1);
@ -446,6 +466,9 @@ struct WarpSpecializedLayerNorm
{
m_base = block_id;
}
const uint32_t eff_m_block
= std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
// if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base);
// Peek for data ready.
@ -613,11 +636,12 @@ struct WarpSpecializedLayerNorm
{
mean[m_offset] /= param.n;
variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset]
+ (Traits::AccumulatorType)(1e-5));
+ (Traits::AccumulatorType)(param.layernorm_eps));
}
else
{
variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5));
variance[m_offset]
= rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
}
}
@ -659,8 +683,7 @@ struct WarpSpecializedLayerNorm
}
}
#pragma unroll Traits::M_BLOCK
for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
for (int m_offset = 0; m_offset < eff_m_block; m_offset++)
{
auto m = m_base + m_offset;
@ -801,23 +824,19 @@ struct WarpSpecializedLayerNorm
shared->init(threadIdx.x == 0);
__syncthreads();
#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
auto block_id = blockIdx.x;
auto warp_id = threadIdx.x / 32;
auto lane_id = threadIdx.x % 32;
auto tid_in_wg = threadIdx.x % 128;
if (warp_id < 4)
{
asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
if (warp_id == 0)
{
scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared);
// PRE-EXIT after all tiles have been scheduled.
cudaTriggerProgrammaticLaunchCompletion();
}
else if (warp_id == 1)
{
@ -829,8 +848,10 @@ struct WarpSpecializedLayerNorm
asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
}
__syncthreads();
asm volatile("membar.gl;" : : : "memory");
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif
}
};

View File

@ -66,6 +66,7 @@ add_library(
fp8Quantize.cpp
dsv3FusedAGemmOp.cpp
fusedQKNormRopeOp.cpp
fusedAddRMSNormQuant.cpp
fusedTopkSoftmax.cpp
gatherTreeOp.cpp
groupRmsNormOp.cpp

View File

@ -0,0 +1,200 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <optional>
#include <tuple>
#include <unordered_map>
TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
// Fused Add + RMSNorm + FP4 Quantization kernel
// input: [M, N] - input tensor (fp16/bf16)
// residual: [M, N] - residual tensor (fp16/bf16)
// gamma: [N] - RMSNorm weight (fp16/bf16)
// sf_scale: [1] - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// Returns:
// normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
// output: [M, N] - pre-norm output (input + residual), same dtype as input
// sf_out: scale factors for FP4 (uint8_t), swizzled layout
//
// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
double eps)
{
CHECK_TH_CUDA(input);
CHECK_CONTIGUOUS(input);
CHECK_TH_CUDA(residual);
CHECK_CONTIGUOUS(residual);
CHECK_TH_CUDA(gamma);
CHECK_CONTIGUOUS(gamma);
// Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell)
auto const device = input.get_device();
cudaDeviceProp props;
AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device));
TORCH_CHECK(props.major >= 9,
"fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. "
"Current device: sm_",
props.major, props.minor);
auto const& inputShape = input.sizes();
auto const& rank = inputShape.size();
TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N].");
TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape.");
int64_t const m = inputShape[0];
int64_t const n = inputShape[1];
// Some warp-specialized kernels may issue vectorized stores that assume M is padded.
// Round M up to a multiple of 32 and allocate the extra rows to avoid out-of-bounds writes.
int64_t const m_padded = (m + 31) / 32 * 32;
TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N.");
TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint).");
TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384.");
TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization.");
// Validate sf_scale if provided
float* sfScalePtr = nullptr;
if (sf_scale.has_value())
{
CHECK_INPUT(sf_scale.value(), torch::kFloat32);
sfScalePtr = sf_scale.value().data_ptr<float>();
}
// Allocate output tensors
// normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32)
// NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
at::Tensor normed_output_padded
= at::detail::empty_cuda({m_padded, n / 8}, torch::kInt32, input.device(), std::nullopt);
at::Tensor normed_output = (m_padded == m) ? normed_output_padded : normed_output_padded.narrow(0, 0, m);
// output: pre-norm output (input + residual) [M, N], same dtype as input
// NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
at::Tensor output_padded = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
at::Tensor output = (m_padded == m) ? output_padded : output_padded.narrow(0, 0, m);
// sf_out: scale factors for FP4, swizzled layout
// sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor)
int64_t const sfVecSize = 16;
// NOTE: allocate using m_padded to avoid OOB writes for warp-specialized/vectorized stores when M is not padded.
// Return a view of the original (un-padded) size to keep the API stable.
int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize);
int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize);
at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt);
at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize);
// Get number of SMs for persistent kernel
static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
// Allocate counters for warp-specialized kernel using PyTorch allocator.
//
// NOTE: We cache this tensor to avoid per-call allocations. We use `thread_local` so
// concurrent calls from different threads don't share the same counters buffer (which
// could cause races across different CUDA streams).
static thread_local std::unordered_map<int, at::Tensor> counters_tensor_cache;
auto& counters_tensor = counters_tensor_cache[device];
int64_t const counters_bytes = static_cast<int64_t>(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters));
if (!counters_tensor.defined() || counters_tensor.numel() != counters_bytes)
{
counters_tensor = at::detail::empty_cuda({counters_bytes}, torch::kByte, input.device(), std::nullopt);
counters_tensor.zero_();
}
auto* counters
= reinterpret_cast<tensorrt_llm::kernels::WarpSpecializedCounters*>(counters_tensor.mutable_data_ptr());
auto stream = at::cuda::getCurrentCUDAStream(device);
#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T) \
do \
{ \
using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam<T>; \
tensorrt_llm::kernels::WarpSpecializedParam<Param> param; \
param.normed_output = reinterpret_cast<uint32_t*>(normed_output.data_ptr()); \
param.output = reinterpret_cast<T*>(output.data_ptr()); \
param.input = const_cast<T*>(reinterpret_cast<T const*>(input.data_ptr())); \
param.sf_scale = sfScalePtr; \
param.sf_out = reinterpret_cast<uint32_t*>(sf_out.data_ptr()); \
param.residual = reinterpret_cast<T const*>(residual.data_ptr()); \
param.bias = nullptr; \
param.gamma = reinterpret_cast<T const*>(gamma.data_ptr()); \
param.beta = nullptr; \
param.m = static_cast<int>(m); \
param.n = static_cast<int>(n); \
param.layernorm_eps = static_cast<float>(eps); \
param.stream = stream; \
param.counters = counters; \
tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount); \
} while (0)
if (input.scalar_type() == at::ScalarType::Half)
{
LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half);
}
else if (input.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16);
#else
C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input.");
#endif
}
else
{
C10_THROW_ERROR(
NotImplementedError, "fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16.");
}
#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT
// No explicit sync needed - kernel runs asynchronously on the stream
return std::make_tuple(normed_output, output, sf_out);
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
"Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fused_add_rms_norm_quant", &tensorrt_llm::torch_ext::fused_add_rms_norm_quant);
}
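For reference, a minimal call sketch against the schema registered above. The shapes, dtypes, and scale value are illustrative assumptions (not taken from this change), and it presumes an SM90+ GPU with the trtllm extension already loaded.

import torch

m, n = 128, 4096  # N must be in [2048, 16384] and divisible by 16
x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
res = torch.randn_like(x)
gamma = torch.ones(n, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # optional, numel == 1

normed_fp4, prenorm, sf_out = torch.ops.trtllm.fused_add_rms_norm_quant(
    x, res, gamma, sf_scale, use_rms_norm=True, eps=1e-6)

assert normed_fp4.shape == (m, n // 8) and normed_fp4.dtype == torch.int32  # packed FP4
assert prenorm.shape == (m, n) and prenorm.dtype == x.dtype                 # pre-norm x + res
# sf_out holds the swizzled FP4 scale factors (see the layout notes above)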

View File

@ -1869,3 +1869,56 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
stream = get_stream(stream_id)
assert stream is not None
tensor.record_stream(stream)
def fused_add_rms_norm_quant(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.
Args:
input: [M, N] input tensor (fp16/bf16)
residual: [M, N] residual tensor (fp16/bf16)
gamma: [N] normalization weight (fp16/bf16)
sf_scale: [1] optional scale factor for FP4 quantization (float32)
use_rms_norm: if True use RMSNorm, else use LayerNorm
eps: epsilon for normalization
Returns:
normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
output: [M, N] pre-norm output (input + residual), same dtype as input
sf_out: scale factors for FP4 quantization (uint8), swizzled layout
Note:
This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
Hidden dimension N must be >= 2048 and <= 16384.
"""
return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma,
sf_scale, use_rms_norm,
eps)
@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
def _fused_add_rms_norm_quant_fake(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m, n = input.shape
# normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
# output: [M, N] pre-norm output, same dtype as input
output = input.new_empty((m, n), dtype=input.dtype)
# sf_out: scale factors, swizzled layout
sf_vec_size = 16
sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
return normed_output_fp4, output, sf_out

View File

@ -625,6 +625,7 @@ class LlamaDecoderLayer(DecoderLayer):
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.num_hidden_layers = config.num_hidden_layers
self.mapping = model_config.mapping
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
@ -649,14 +650,30 @@ class LlamaDecoderLayer(DecoderLayer):
layer_idx=layer_idx,
use_custom_cublas_mm=use_custom_cublas_mm,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
differ_pp_stage_with_previous_layer = False
if self.mapping.has_pp():
prev_layer_idx = max(self.layer_idx - 1, 0)
differ_pp_stage_with_previous_layer = (
self.layer_idx > 0 and self.mapping.pp_rank_of_layer(
self.layer_idx,
self.num_hidden_layers) != self.mapping.pp_rank_of_layer(
prev_layer_idx, self.num_hidden_layers))
self.disable_nvfp4_layernorm_fusion = os.environ.get(
"TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "1") == "1"
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
quantize_type="nvfp4"
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
and not (differ_pp_stage_with_previous_layer) else None)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion
and self.is_nvfp4 else None)
self.all_reduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
@ -676,7 +693,6 @@ class LlamaDecoderLayer(DecoderLayer):
self.PRE_MLP_FUSION = self.mapping.has_tp(
) and not self.enable_attention_dp and self.enable_fusion
self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
if self.is_nvfp4:
self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
@ -697,17 +713,16 @@ class LlamaDecoderLayer(DecoderLayer):
def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, Fp4QuantizedTensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
@ -739,6 +754,8 @@ class LlamaDecoderLayer(DecoderLayer):
else:
hidden_states, residual = all_reduce_output
else:
if self.is_nvfp4:
self.post_attention_layernorm.nvfp4_scale = self.mlp.gate_up_proj.input_scale
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
@ -803,6 +820,12 @@ class LlamaDecoderLayer(DecoderLayer):
else:
hidden_states, residual = all_reduce_output
elif self.next_layer_layernorm:
# NOTE: for the last decoder layer, `next_layer_layernorm` is the final model norm without nvfp4 quant
# (`self.model.norm`), and `next_attn` is expected to be None.
if self.next_attn is not None and hasattr(self.next_attn.qkv_proj,
'input_scale'):
self.next_layer_layernorm.nvfp4_scale = self.next_attn.qkv_proj.input_scale
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)

View File

@ -20,7 +20,10 @@ from typing import Optional, Tuple, TypeAlias, Union, cast
import torch
from torch import nn
from tensorrt_llm.logger import logger
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..utils import Fp4QuantizedTensor
class RMSNorm(nn.Module):
@ -37,11 +40,17 @@ class RMSNorm(nn.Module):
device: Optional[torch.device] = None,
has_weights: bool = True,
use_gemma: bool = False,
quantize_type: Optional[str] = None,
):
super().__init__()
if use_gemma and not has_weights:
raise ValueError("has_weights must be True if use_gemma is True")
if quantize_type is not None:
if quantize_type != "nvfp4":
raise NotImplementedError(
f"Quantize type {quantize_type} not implemented in RMSNorm")
self.is_nvfp4 = quantize_type == "nvfp4"
if has_weights:
if not use_gemma:
@ -65,12 +74,112 @@ class RMSNorm(nn.Module):
residual: Union[
Optional[torch.Tensor],
_ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]:
return_residual = True
if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL:
return_residual = False
residual = None
if self.is_nvfp4 and residual is not None and not self.use_gemma:
nvfp4_scale = getattr(self, "nvfp4_scale", None)
if nvfp4_scale is None:
raise ValueError(
f"layeridx={getattr(self, 'layer_idx', None)} RMSNorm NVFP4 output requested "
"but no `nvfp4_scale` is attached; ")
else:
def _can_use_fused_kernel() -> Tuple[bool, str]:
if not hidden_states.is_cuda or not residual.is_cuda:
return False, "inputs must be CUDA tensors"
if not self.weight.is_cuda:
return False, "gamma/weight must be a CUDA tensor"
if hidden_states.ndim < 2:
return False, "input must have rank >= 2"
if hidden_states.shape != residual.shape:
return False, f"input/residual shape mismatch: {tuple(hidden_states.shape)} vs {tuple(residual.shape)}"
n = int(hidden_states.shape[-1])
if self.weight.ndim != 1 or int(self.weight.numel()) != n:
return False, f"gamma/weight must be 1D with numel == hidden_size ({n}), got shape={tuple(self.weight.shape)}"
# Match the underlying C++ op: fp16/bf16 only (no fp8).
if hidden_states.dtype not in (torch.float16,
torch.bfloat16):
return False, f"unsupported dtype {hidden_states.dtype} (expected fp16/bf16)"
if n % 16 != 0:
return False, f"hidden size must be divisible by 16 (got {n})"
# Kernel constraints (see fusedAddRMSNormQuant.cpp).
if n < 2048 or n > 16384:
return False, f"hidden size must be in [2048, 16384] (got {n})"
# SM90+ only.
major, _minor = torch.cuda.get_device_capability(
hidden_states.device)
if major < 9:
return False, f"requires SM90+ GPU, got SM{major}{_minor}"
# Scale tensor constraints.
if (nvfp4_scale is not None
and ((not nvfp4_scale.is_cuda) or nvfp4_scale.dtype
!= torch.float32 or nvfp4_scale.numel() != 1)):
return False, f"nvfp4_scale must be a CUDA float32 tensor with numel==1 (got dtype={getattr(nvfp4_scale, 'dtype', None)}, device={getattr(nvfp4_scale, 'device', None)}, numel={getattr(nvfp4_scale, 'numel', lambda: None)()})"
return True, ""
ok, reason = _can_use_fused_kernel()
if not ok:
raise RuntimeError(
"RMSNorm NVFP4 fused path disabled due to unsupported inputs "
f"(falling back to unfused RMSNorm): {reason}")
else:
from ..custom_ops.torch_custom_ops import \
fused_add_rms_norm_quant
orig_shape = tuple(hidden_states.shape)
n = int(orig_shape[-1])
hs_2d = hidden_states.reshape(-1, n).contiguous()
res_2d = residual.reshape(-1, n)
gamma = self.weight
def _ensure_contiguous_with_dtype(t: torch.Tensor,
key: str):
if t.dtype != hs_2d.dtype:
logger.warning_once(
f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}.",
key=f"rmsnorm_nvfp4_cast_{key}",
)
t = t.to(dtype=hs_2d.dtype)
return t.contiguous()
res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
gamma = _ensure_contiguous_with_dtype(gamma, "gamma")
if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
raise RuntimeError(
"RMSNorm NVFP4 fused path requires all tensors on the same device. "
f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
)
sf_scale = nvfp4_scale.contiguous(
) if nvfp4_scale is not None else None
normed_fp4_i32, residual_out_2d, sf_fused = fused_add_rms_norm_quant(
hs_2d,
res_2d,
gamma,
sf_scale,
True,
eps=self.variance_epsilon,
)
normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
if len(orig_shape) != 2:
normed_fp4_u8 = normed_fp4_u8.reshape(
*orig_shape[:-1], n // 2)
residual_out = residual_out_2d.reshape(orig_shape)
else:
residual_out = residual_out_2d
hidden_states_fused = Fp4QuantizedTensor(
normed_fp4_u8, sf_fused)
return (hidden_states_fused, residual_out
) if return_residual else hidden_states_fused
if IS_FLASHINFER_AVAILABLE:
from ..custom_ops import (flashinfer_fused_add_rmsnorm,
flashinfer_gemma_fused_add_rmsnorm,

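To show how the fused branch above is driven end to end, a hedged sketch of a caller follows; the import path, hidden size, and scale value are assumptions for illustration, an SM90+ GPU is assumed, and in the model code it is the decoder layer that attaches `nvfp4_scale` (the next GEMM's `input_scale`) before invoking the norm.

import torch
from tensorrt_llm._torch.modules.rms_norm import RMSNorm  # module path assumed

norm = RMSNorm(hidden_size=4096, eps=1e-6, dtype=torch.bfloat16,
               quantize_type="nvfp4").cuda()
norm.nvfp4_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # stand-in for the next GEMM's input_scale

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.zeros_like(x)
normed, new_residual = norm(x, residual)
# `normed` is an Fp4QuantizedTensor (packed FP4 data plus swizzled scale factors);
# `new_residual` is the pre-norm sum x + residual in the input dtype.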
View File

@ -328,6 +328,39 @@ class MappingBase:
return torch.tensor_split(torch.arange(num_layers),
self.pp_size)[self.pp_rank].tolist()
def pp_rank_of_layer(self, layer_idx: int, num_layers: int) -> int:
"""Return pipeline-parallel rank that owns `layer_idx` for a model with `num_layers` layers.
Mirrors the partitioning behavior in `pp_layers()`.
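Example (illustrative; `mapping` is an assumed MappingBase instance with pp_size == 4
and no pp_partition): with 10 layers, `torch.tensor_split` yields chunks of sizes
[3, 3, 2, 2], so

    >>> mapping.pp_rank_of_layer(5, num_layers=10)
    1
    >>> mapping.pp_rank_of_layer(6, num_layers=10)
    2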
"""
if layer_idx < 0 or layer_idx >= num_layers:
raise ValueError(f"{layer_idx=} is out of range for {num_layers=}.")
if not self.has_pp():
return 0
if self.pp_partition is not None:
if len(self.pp_partition) != self.pp_size:
raise ValueError(
f"{len(self.pp_partition)=} does not match {self.pp_size=}."
)
if sum(self.pp_partition) != num_layers:
raise ValueError(
f"{sum(self.pp_partition)=} does not match {num_layers=}.")
end = 0
for pp_rank, n in enumerate(self.pp_partition):
end += n
if layer_idx < end:
return pp_rank
raise RuntimeError("Unreachable: invalid pp_partition.")
base, rem = divmod(num_layers, self.pp_size)
if base == 0:
# Matches torch.tensor_split: first `num_layers` ranks get one layer.
return layer_idx
cutoff = (base + 1) * rem
if layer_idx < cutoff:
return layer_idx // (base + 1)
return rem + (layer_idx - cutoff) // base
def ep_experts(self, num_experts: int) -> List[int]:
assert self.cp_size == 1
experts_per_rank = num_experts // self.moe_ep_size