Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Commit ae855d7f67: Merge a08e8f7bbc into c1b0b7350f
@@ -115,8 +115,6 @@ struct LowLatencyLayerNorm

uint32_t work_id = blockIdx.x;

FusedOperator fused_operator(param);

constexpr auto PACKED_PER_N_BLOCK = Traits::N_BLOCK / N_THREADS / Traits::PACKED_ELEMS_PER_COMPUTE;

typename Traits::AccumulatorType data[PACKED_PER_N_BLOCK][Traits::PACKED_ELEMS_PER_COMPUTE];
@@ -139,7 +137,7 @@ struct LowLatencyLayerNorm
for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
{
auto offset = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE;
if (offset <= sz)
if (offset < sz)
{
data[i] = *reinterpret_cast<PackedType const*>(&g_data[offset]);
}
@@ -155,6 +153,14 @@ struct LowLatencyLayerNorm

static_assert(Traits::OUTPUT_SCALE != SCALE_TYPE::VECTOR);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaGridDependencySynchronize();
}
#endif
FusedOperator fused_operator(param);

if constexpr (Traits::BIAS == SCALE_TYPE::VECTOR)
{
load_to_register(param.bias, r_bias, param.n);
@@ -175,13 +181,6 @@ struct LowLatencyLayerNorm
load_to_register(param.beta, r_beta, param.n);
}

#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
load_to_register(&param.input[work_id * param.n], data, param.n);

if constexpr (Traits::RESIDUAL)
@@ -259,12 +258,12 @@ struct LowLatencyLayerNorm
if constexpr (!Traits::RMS_NORM)
{
mean = var_and_mean[1] / param.n;
variance = rsqrtf(
var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(1e-5));
variance = rsqrtf(var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1]
+ (Traits::AccumulatorType)(param.layernorm_eps));
}
else
{
variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(1e-5));
variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
}

for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
@@ -333,6 +332,14 @@ struct LowLatencyLayerNorm
{
__shared__ Shared shared;
compute(param, &shared);
__syncthreads();
asm volatile("membar.gl;" : : : "memory");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
}
};
@@ -201,25 +201,35 @@ struct WarpSpecializedLayerNorm
}
// if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base);

if constexpr (FIRST_RUN)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
// Ensure upstream kernel writes are visible before reading dependent activation/residual data.
cudaGridDependencySynchronize();
}
#endif
}
const uint32_t eff_m_block
= std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
const auto tx
= (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+ (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0);
= (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+ (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0);

auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx);

// if (blockIdx.x == 0) printf("SMEM buffer ready, start loading tile %d.\n", m_base);

if constexpr (FIRST_RUN)
{
cudaGridDependencySynchronize();
}

for (int i = 0; i < Traits::M_BLOCK; i++)
{
load_a_vec(&param.input[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
if (i < eff_m_block) [[likely]]
{
load_a_vec(&param.input[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
}
}

// Use templated lambdas to defer resolving the symbols like "param.residual".
@@ -231,10 +241,13 @@ struct WarpSpecializedLayerNorm
{
for (int i = 0; i < Traits::M_BLOCK; i++)
{
load_a_vec(&param.residual[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
if (i < eff_m_block) [[likely]]
{
load_a_vec(&param.residual[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
}
}
}(param);
}
@@ -423,6 +436,13 @@ struct WarpSpecializedLayerNorm

using FusedOperator = GetFusedOperator<typename Traits::FusedOperator>;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
// Ensure upstream kernel writes are visible before reading dependent activation/residual data.
cudaGridDependencySynchronize();
}
#endif
FusedOperator fused_operator(param);

static_assert(Traits::PERSISTENT_MODE || Traits::MATH_WARPGROUPS == 1);
@@ -446,6 +466,9 @@ struct WarpSpecializedLayerNorm
{
m_base = block_id;
}
const uint32_t eff_m_block
= std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));

// if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base);

// Peek for data ready.
@@ -613,11 +636,12 @@ struct WarpSpecializedLayerNorm
{
mean[m_offset] /= param.n;
variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset]
+ (Traits::AccumulatorType)(1e-5));
+ (Traits::AccumulatorType)(param.layernorm_eps));
}
else
{
variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5));
variance[m_offset]
= rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
}
}

@@ -659,8 +683,7 @@ struct WarpSpecializedLayerNorm
}
}

#pragma unroll Traits::M_BLOCK
for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
for (int m_offset = 0; m_offset < eff_m_block; m_offset++)
{
auto m = m_base + m_offset;

@@ -801,23 +824,19 @@ struct WarpSpecializedLayerNorm
shared->init(threadIdx.x == 0);

__syncthreads();
#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
auto block_id = blockIdx.x;
auto warp_id = threadIdx.x / 32;
auto lane_id = threadIdx.x % 32;
auto tid_in_wg = threadIdx.x % 128;

if (warp_id < 4)
{
asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
if (warp_id == 0)
{
scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared);
// PRE-EXIT after all tiles have been scheduled.
cudaTriggerProgrammaticLaunchCompletion();
}
else if (warp_id == 1)
{
@@ -829,8 +848,10 @@ struct WarpSpecializedLayerNorm
asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
}
__syncthreads();
asm volatile("membar.gl;" : : : "memory");
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif
}
};
@@ -66,6 +66,7 @@ add_library(
fp8Quantize.cpp
dsv3FusedAGemmOp.cpp
fusedQKNormRopeOp.cpp
fusedAddRMSNormQuant.cpp
fusedTopkSoftmax.cpp
gatherTreeOp.cpp
groupRmsNormOp.cpp

cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp (new file, 200 lines)
@@ -0,0 +1,200 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <optional>
#include <tuple>
#include <unordered_map>

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{

// Fused Add + RMSNorm + FP4 Quantization kernel
// input: [M, N] - input tensor (fp16/bf16)
// residual: [M, N] - residual tensor (fp16/bf16)
// gamma: [N] - RMSNorm weight (fp16/bf16)
// sf_scale: [1] - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// Returns:
// normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
// output: [M, N] - pre-norm output (input + residual), same dtype as input
// sf_out: scale factors for FP4 (uint8_t), swizzled layout
//
// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
double eps)
{
CHECK_TH_CUDA(input);
CHECK_CONTIGUOUS(input);
CHECK_TH_CUDA(residual);
CHECK_CONTIGUOUS(residual);
CHECK_TH_CUDA(gamma);
CHECK_CONTIGUOUS(gamma);

// Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell)
auto const device = input.get_device();
cudaDeviceProp props;
AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device));
TORCH_CHECK(props.major >= 9,
"fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. "
"Current device: sm_",
props.major, props.minor);

auto const& inputShape = input.sizes();
auto const& rank = inputShape.size();

TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N].");
TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape.");

int64_t const m = inputShape[0];
int64_t const n = inputShape[1];
// Some warp-specialized kernels may issue vectorized stores that assume M is padded.
// Allocate a bit of extra space to avoid out-of-bounds writes when M is not a multiple of 8.
int64_t const m_padded = (m + 31) / 32 * 32;

TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N.");
TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint).");
TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384.");
TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization.");

// Validate sf_scale if provided
float* sfScalePtr = nullptr;
if (sf_scale.has_value())
{
CHECK_INPUT(sf_scale.value(), torch::kFloat32);
sfScalePtr = sf_scale.value().data_ptr<float>();
}

// Allocate output tensors
// normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32)
// NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
at::Tensor normed_output_padded
= at::detail::empty_cuda({m_padded, n / 8}, torch::kInt32, input.device(), std::nullopt);
at::Tensor normed_output = (m_padded == m) ? normed_output_padded : normed_output_padded.narrow(0, 0, m);

// output: pre-norm output (input + residual) [M, N], same dtype as input
// NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
at::Tensor output_padded = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
at::Tensor output = (m_padded == m) ? output_padded : output_padded.narrow(0, 0, m);

// sf_out: scale factors for FP4, swizzled layout
// sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor)
int64_t const sfVecSize = 16;
// NOTE: allocate using m_padded to avoid OOB writes for warp-specialized/vectorized stores when M is not padded.
// Return a view of the original (un-padded) size to keep the API stable.
int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize);
int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize);
at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt);
at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize);

// Get number of SMs for persistent kernel
static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

// Allocate counters for warp-specialized kernel using PyTorch allocator.
//
// NOTE: We cache this tensor to avoid per-call allocations. We use `thread_local` so
// concurrent calls from different threads don't share the same counters buffer (which
// could cause races across different CUDA streams).
static thread_local std::unordered_map<int, at::Tensor> counters_tensor_cache;
auto& counters_tensor = counters_tensor_cache[device];
int64_t const counters_bytes = static_cast<int64_t>(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters));
if (!counters_tensor.defined() || counters_tensor.numel() != counters_bytes)
{
counters_tensor = at::detail::empty_cuda({counters_bytes}, torch::kByte, input.device(), std::nullopt);
counters_tensor.zero_();
}
auto* counters
= reinterpret_cast<tensorrt_llm::kernels::WarpSpecializedCounters*>(counters_tensor.mutable_data_ptr());

auto stream = at::cuda::getCurrentCUDAStream(device);

#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T) \
do \
{ \
using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam<T>; \
tensorrt_llm::kernels::WarpSpecializedParam<Param> param; \
param.normed_output = reinterpret_cast<uint32_t*>(normed_output.data_ptr()); \
param.output = reinterpret_cast<T*>(output.data_ptr()); \
param.input = const_cast<T*>(reinterpret_cast<T const*>(input.data_ptr())); \
param.sf_scale = sfScalePtr; \
param.sf_out = reinterpret_cast<uint32_t*>(sf_out.data_ptr()); \
param.residual = reinterpret_cast<T const*>(residual.data_ptr()); \
param.bias = nullptr; \
param.gamma = reinterpret_cast<T const*>(gamma.data_ptr()); \
param.beta = nullptr; \
param.m = static_cast<int>(m); \
param.n = static_cast<int>(n); \
param.layernorm_eps = static_cast<float>(eps); \
param.stream = stream; \
param.counters = counters; \
tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount); \
} while (0)

if (input.scalar_type() == at::ScalarType::Half)
{
LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half);
}
else if (input.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16);
#else
C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input.");
#endif
}
else
{
C10_THROW_ERROR(
NotImplementedError, "fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16.");
}

#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT
// No explicit sync needed - kernel runs asynchronously on the stream
return std::make_tuple(normed_output, output, sf_out);
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
"Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fused_add_rms_norm_quant", &tensorrt_llm::torch_ext::fused_add_rms_norm_quant);
}
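The padded-allocation pattern above (allocate [M_padded, N], hand back a narrowed view of [M, N]) is worth calling out: the kernel may store whole row tiles, so the extra rows absorb stores past the last real row while callers only ever see the logical shape. A minimal PyTorch sketch of the same idea, with the padding granularity of 32 taken from the code above (the helper name is illustrative):

    import torch

    def alloc_padded_rows(m: int, n: int, dtype: torch.dtype,
                          device: torch.device) -> torch.Tensor:
        # Round M up to a multiple of 32 so tile-granular stores past the last
        # real row stay inside the allocation, then return a view of the first
        # m rows so the caller never sees the padding.
        m_padded = (m + 31) // 32 * 32
        buf = torch.empty((m_padded, n), dtype=dtype, device=device)
        return buf if m_padded == m else buf.narrow(0, 0, m)

    out = alloc_padded_rows(100, 4096, torch.bfloat16, torch.device("cpu"))
    assert out.shape == (100, 4096)  # backing storage actually holds 128 rows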
@@ -1869,3 +1869,56 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
stream = get_stream(stream_id)
assert stream is not None
tensor.record_stream(stream)


def fused_add_rms_norm_quant(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.

Args:
input: [M, N] input tensor (fp16/bf16)
residual: [M, N] residual tensor (fp16/bf16)
gamma: [N] normalization weight (fp16/bf16)
sf_scale: [1] optional scale factor for FP4 quantization (float32)
use_rms_norm: if True use RMSNorm, else use LayerNorm
eps: epsilon for normalization

Returns:
normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
output: [M, N] pre-norm output (input + residual), same dtype as input
sf_out: scale factors for FP4 quantization (uint8), swizzled layout

Note:
This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
Hidden dimension N must be >= 2048 and <= 16384.
"""
return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma,
sf_scale, use_rms_norm,
eps)
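For orientation, a hypothetical call site for this wrapper; the shapes, dtype, and scale values below are illustrative only, and the call assumes an SM90+ GPU with the trtllm extension loaded:

    import torch

    m, n = 8, 4096  # n must lie in [2048, 16384] and be divisible by 16
    x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
    res = torch.randn_like(x)
    gamma = torch.ones(n, dtype=torch.bfloat16, device="cuda")
    sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")

    normed_fp4, prenorm, sf = fused_add_rms_norm_quant(
        x, res, gamma, sf_scale, use_rms_norm=True, eps=1e-6)
    # normed_fp4: [8, 512] int32 (packed FP4), prenorm: [8, 4096] bf16 (x + res),
    # sf: uint8 scale factors in the swizzled layout.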


@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
def _fused_add_rms_norm_quant_fake(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m, n = input.shape
# normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
# output: [M, N] pre-norm output, same dtype as input
output = input.new_empty((m, n), dtype=input.dtype)
# sf_out: scale factors, swizzled layout
sf_vec_size = 16
sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
return normed_output_fp4, output, sf_out
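The sf_size arithmetic in the fake mirrors the swizzled scale-factor layout: rows are padded to a multiple of 128 and the per-row scale count (n // 16) to a multiple of 4. Assuming that formula matches computeSwizzledLayoutSFSize on the C++ side, a quick worked example:

    def swizzled_sf_size(m: int, n: int, sf_vec_size: int = 16) -> int:
        # Same arithmetic as the fake above: pad rows to 128, scale columns to 4.
        return ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4

    # m=100, n=4096 -> 128 padded rows * 256 scale columns = 32768 scale bytes
    assert swizzled_sf_size(100, 4096) == 128 * 256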
@@ -625,6 +625,7 @@ class LlamaDecoderLayer(DecoderLayer):
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.num_hidden_layers = config.num_hidden_layers
self.mapping = model_config.mapping
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
@@ -649,14 +650,30 @@ class LlamaDecoderLayer(DecoderLayer):
layer_idx=layer_idx,
use_custom_cublas_mm=use_custom_cublas_mm,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
differ_pp_stage_with_previous_layer = False
if self.mapping.has_pp():
prev_layer_idx = max(self.layer_idx - 1, 0)
differ_pp_stage_with_previous_layer = (
self.layer_idx > 0 and self.mapping.pp_rank_of_layer(
self.layer_idx,
self.num_hidden_layers) != self.mapping.pp_rank_of_layer(
prev_layer_idx, self.num_hidden_layers))
self.disable_nvfp4_layernorm_fusion = os.environ.get(
"TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "1") == "1"
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
quantize_type="nvfp4"
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
and not (differ_pp_stage_with_previous_layer) else None)

self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion
and self.is_nvfp4 else None)
self.all_reduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
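Note that the fused NVFP4 layernorm path is opt-in: TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION defaults to "1", so the quantizing RMSNorm is only constructed when the variable is set to something else before the layer is built. A sketch (the value "0" stands in for any string other than "1"):

    import os

    # Must be set before LlamaDecoderLayer is constructed, since the flag is
    # read once in __init__.
    os.environ["TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION"] = "0"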
@@ -676,7 +693,6 @@ class LlamaDecoderLayer(DecoderLayer):
self.PRE_MLP_FUSION = self.mapping.has_tp(
) and not self.enable_attention_dp and self.enable_fusion
self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion

if self.is_nvfp4:
self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
@@ -697,17 +713,16 @@ class LlamaDecoderLayer(DecoderLayer):
def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, Fp4QuantizedTensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
@@ -739,6 +754,8 @@ class LlamaDecoderLayer(DecoderLayer):
else:
hidden_states, residual = all_reduce_output
else:
if self.is_nvfp4:
self.post_attention_layernorm.nvfp4_scale = self.mlp.gate_up_proj.input_scale
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)

@@ -803,6 +820,12 @@ class LlamaDecoderLayer(DecoderLayer):
else:
hidden_states, residual = all_reduce_output
elif self.next_layer_layernorm:
# NOTE: for the last decoder layer, `next_layer_layernorm` is the final model norm without nvfp4 quant
# (`self.model.norm`), and `next_attn` is expected to be None.
if self.next_attn is not None and hasattr(self.next_attn.qkv_proj,
'input_scale'):
self.next_layer_layernorm.nvfp4_scale = self.next_attn.qkv_proj.input_scale

hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
@@ -20,7 +20,10 @@ from typing import Optional, Tuple, TypeAlias, Union, cast
import torch
from torch import nn

from tensorrt_llm.logger import logger

from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..utils import Fp4QuantizedTensor


class RMSNorm(nn.Module):
@@ -37,11 +40,17 @@ class RMSNorm(nn.Module):
device: Optional[torch.device] = None,
has_weights: bool = True,
use_gemma: bool = False,
quantize_type: Optional[str] = None,
):
super().__init__()

if use_gemma and not has_weights:
raise ValueError("has_weights must be True if use_gemma is True")
if quantize_type is not None:
if quantize_type != "nvfp4":
raise NotImplementedError(
f"Quantize type {quantize_type} not implemented in RMSNorm")
self.is_nvfp4 = quantize_type == "nvfp4"

if has_weights:
if not use_gemma:
@@ -65,12 +74,112 @@ class RMSNorm(nn.Module):
residual: Union[
Optional[torch.Tensor],
_ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]:
return_residual = True
if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL:
return_residual = False
residual = None

if self.is_nvfp4 and residual is not None and not self.use_gemma:
nvfp4_scale = getattr(self, "nvfp4_scale", None)
if nvfp4_scale is None:
raise ValueError(
f"layer_idx={getattr(self, 'layer_idx', None)}: RMSNorm NVFP4 output requested "
"but no `nvfp4_scale` is attached.")
else:

def _can_use_fused_kernel() -> Tuple[bool, str]:
if not hidden_states.is_cuda or not residual.is_cuda:
return False, "inputs must be CUDA tensors"
if not self.weight.is_cuda:
return False, "gamma/weight must be a CUDA tensor"
if hidden_states.ndim < 2:
return False, "input must have rank >= 2"
if hidden_states.shape != residual.shape:
return False, f"input/residual shape mismatch: {tuple(hidden_states.shape)} vs {tuple(residual.shape)}"
n = int(hidden_states.shape[-1])
if self.weight.ndim != 1 or int(self.weight.numel()) != n:
return False, f"gamma/weight must be 1D with numel == hidden_size ({n}), got shape={tuple(self.weight.shape)}"
# Match the underlying C++ op: fp16/bf16 only (no fp8).
if hidden_states.dtype not in (torch.float16,
torch.bfloat16):
return False, f"unsupported dtype {hidden_states.dtype} (expected fp16/bf16)"
if n % 16 != 0:
return False, f"hidden size must be divisible by 16 (got {n})"
# Kernel constraints (see fusedAddRMSNormQuant.cpp).
if n < 2048 or n > 16384:
return False, f"hidden size must be in [2048, 16384] (got {n})"
# SM90+ only.
major, _minor = torch.cuda.get_device_capability(
hidden_states.device)
if major < 9:
return False, f"requires SM90+ GPU, got SM{major}{_minor}"
# Scale tensor constraints.
if (nvfp4_scale is not None
and ((not nvfp4_scale.is_cuda) or nvfp4_scale.dtype
!= torch.float32 or nvfp4_scale.numel() != 1)):
return False, f"nvfp4_scale must be a CUDA float32 tensor with numel==1 (got dtype={getattr(nvfp4_scale, 'dtype', None)}, device={getattr(nvfp4_scale, 'device', None)}, numel={getattr(nvfp4_scale, 'numel', lambda: None)()})"
return True, ""

ok, reason = _can_use_fused_kernel()
if not ok:
raise RuntimeError(
"RMSNorm NVFP4 fused path requested but the inputs are not "
f"supported by the fused kernel: {reason}")
else:
from ..custom_ops.torch_custom_ops import \
fused_add_rms_norm_quant

orig_shape = tuple(hidden_states.shape)
n = int(orig_shape[-1])
hs_2d = hidden_states.reshape(-1, n).contiguous()
res_2d = residual.reshape(-1, n)
gamma = self.weight

def _ensure_contiguous_with_dtype(t: torch.Tensor,
key: str):
if t.dtype != hs_2d.dtype:
logger.warning_once(
f"RMSNorm NVFP4 fused path: casting {key} from {t.dtype} to {hs_2d.dtype}.",
key=f"rmsnorm_nvfp4_cast_{key}",
)
t = t.to(dtype=hs_2d.dtype)
return t.contiguous()

res_2d = _ensure_contiguous_with_dtype(res_2d, "residual")
gamma = _ensure_contiguous_with_dtype(gamma, "gamma")

if hs_2d.device != res_2d.device or hs_2d.device != gamma.device:
raise RuntimeError(
"RMSNorm NVFP4 fused path requires all tensors on the same device. "
f"Got input={hs_2d.device}, residual={res_2d.device}, gamma={gamma.device}."
)

sf_scale = nvfp4_scale.contiguous(
) if nvfp4_scale is not None else None

normed_fp4_i32, residual_out_2d, sf_fused = fused_add_rms_norm_quant(
hs_2d,
res_2d,
gamma,
sf_scale,
True,
eps=self.variance_epsilon,
)
normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
if len(orig_shape) != 2:
normed_fp4_u8 = normed_fp4_u8.reshape(
*orig_shape[:-1], n // 2)
residual_out = residual_out_2d.reshape(orig_shape)
else:
residual_out = residual_out_2d

hidden_states_fused = Fp4QuantizedTensor(
normed_fp4_u8, sf_fused)
return (hidden_states_fused, residual_out
) if return_residual else hidden_states_fused

if IS_FLASHINFER_AVAILABLE:
from ..custom_ops import (flashinfer_fused_add_rmsnorm,
flashinfer_gemma_fused_add_rmsnorm,
@@ -328,6 +328,39 @@ class MappingBase:
return torch.tensor_split(torch.arange(num_layers),
self.pp_size)[self.pp_rank].tolist()

def pp_rank_of_layer(self, layer_idx: int, num_layers: int) -> int:
"""Return pipeline-parallel rank that owns `layer_idx` for a model with `num_layers` layers.
Mirrors the partitioning behavior in `pp_layers()`.
"""
if layer_idx < 0 or layer_idx >= num_layers:
raise ValueError(f"{layer_idx=} is out of range for {num_layers=}.")
if not self.has_pp():
return 0

if self.pp_partition is not None:
if len(self.pp_partition) != self.pp_size:
raise ValueError(
f"{len(self.pp_partition)=} does not match {self.pp_size=}."
)
if sum(self.pp_partition) != num_layers:
raise ValueError(
f"{sum(self.pp_partition)=} does not match {num_layers=}.")
end = 0
for pp_rank, n in enumerate(self.pp_partition):
end += n
if layer_idx < end:
return pp_rank
raise RuntimeError("Unreachable: invalid pp_partition.")

base, rem = divmod(num_layers, self.pp_size)
if base == 0:
# Matches torch.tensor_split: first `num_layers` ranks get one layer.
return layer_idx
cutoff = (base + 1) * rem
if layer_idx < cutoff:
return layer_idx // (base + 1)
return rem + (layer_idx - cutoff) // base
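Because the docstring promises parity with pp_layers() (which uses torch.tensor_split), the even-split branch can be sanity-checked against it directly; the standalone helper below reproduces the arithmetic for an assumed 32-layer model on 3 pipeline ranks:

    import torch

    def rank_of_layer(layer_idx: int, num_layers: int, pp_size: int) -> int:
        # Same even-split arithmetic as pp_rank_of_layer (no pp_partition set).
        base, rem = divmod(num_layers, pp_size)
        if base == 0:
            return layer_idx  # first `num_layers` ranks get one layer each
        cutoff = (base + 1) * rem  # the first `rem` ranks hold base + 1 layers
        if layer_idx < cutoff:
            return layer_idx // (base + 1)
        return rem + (layer_idx - cutoff) // base

    num_layers, pp_size = 32, 3
    chunks = torch.tensor_split(torch.arange(num_layers), pp_size)
    expected = {int(l): r for r, chunk in enumerate(chunks) for l in chunk}
    assert all(rank_of_layer(l, num_layers, pp_size) == expected[l]
               for l in range(num_layers))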

def ep_experts(self, num_experts: int) -> List[int]:
assert self.cp_size == 1
experts_per_rank = num_experts // self.moe_ep_size