[None][feat] Optimize NemotronH model with elementwise and nvfp4 fusion (#11273)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli Jiang 2026-02-12 22:25:31 +08:00 committed by GitHub
parent ef7830d137
commit 421eb9e39c
27 changed files with 2203 additions and 214 deletions

View File

@ -43,6 +43,8 @@ struct Causal_conv1d_fwd_kernel_traits
static_assert(kWidth <= kNElts);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
static_assert(kNThreads_ % 32 == 0, "kNThreads must be a multiple of 32 for warp shuffle");
static_assert(sizeof(vec_t) == 16, "vec_t must be 16 bytes for warp shuffle optimization");
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
@ -123,7 +125,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
#pragma unroll
for (int i = 0; i < kWidth; ++i)
{
weight_vals[i] = float(weight[i * params.weight_width_stride]);
weight_vals[i] = float(__ldg(&weight[i * params.weight_width_stride]));
}
constexpr int kChunkSize = kNThreads * kNElts;
@ -144,20 +146,41 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
x, *reinterpret_cast<input_t(*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
}
x += kChunkSize;
int const lane_id = tidx & 31;
vec_t high_val = reinterpret_cast<vec_t*>(x_vals_load)[1];
__syncthreads();
// Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
// the last elements of the previous chunk.
if (tidx < kNThreads - 1)
{
smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
smem_exchange[tidx] = high_val;
}
__syncthreads();
reinterpret_cast<vec_t*>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
// Get neighbor data: use warp shuffle for most threads, shared memory for warp boundaries
vec_t neighbor;
uint32_t* high_val_p = reinterpret_cast<uint32_t*>(&high_val);
uint32_t* nbr_p = reinterpret_cast<uint32_t*>(&neighbor);
nbr_p[0] = __shfl_up_sync(0xFFFFFFFF, high_val_p[0], 1);
nbr_p[1] = __shfl_up_sync(0xFFFFFFFF, high_val_p[1], 1);
nbr_p[2] = __shfl_up_sync(0xFFFFFFFF, high_val_p[2], 1);
nbr_p[3] = __shfl_up_sync(0xFFFFFFFF, high_val_p[3], 1);
// Lane 0 must use shared memory to handle the cross-warp boundary.
// Thread 0 wraps around and reads the last elements of the previous chunk.
if (lane_id == 0)
{
neighbor = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
}
reinterpret_cast<vec_t*>(x_vals_load)[0] = neighbor;
__syncthreads();
// Now thread kNThreads - 1 can write the last elements of the current chunk.
if (tidx == kNThreads - 1)
{
smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
smem_exchange[tidx] = high_val;
}
float x_vals[2 * kNElts];
@ -169,22 +192,33 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
float out_vals[kNElts];
#pragma unroll
for (int i = 0; i < kNElts; ++i)
// Process 2 outputs at a time for better ILP (instruction-level parallelism).
for (int i = 0; i < kNElts; i += 2)
{
out_vals[i] = bias_val;
float acc0 = bias_val;
float acc1 = bias_val;
#pragma unroll
for (int w = 0; w < kWidth; ++w)
{
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
float wt = weight_vals[w];
acc0 = __fmaf_rn(wt, x_vals[kNElts + i - (kWidth - w - 1)], acc0);
acc1 = __fmaf_rn(wt, x_vals[kNElts + i + 1 - (kWidth - w - 1)], acc1);
}
out_vals[i] = acc0;
out_vals[i + 1] = acc1;
}
if (params.silu_activation)
{
#pragma unroll
for (int i = 0; i < kNElts; ++i)
for (int i = 0; i < kNElts; i += 2)
{
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
// SiLU: x * sigmoid(x) = x / (1 + exp(-x))
// Using fast math: __expf and __frcp_rn
float v0 = out_vals[i];
float v1 = out_vals[i + 1];
out_vals[i] = v0 * __frcp_rn(1.0f + __expf(-v0));
out_vals[i + 1] = v1 * __frcp_rn(1.0f + __expf(-v1));
}
}
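For reference, the arithmetic this hunk reorganizes (a causal depthwise convolution followed by SiLU) is unchanged; only the schedule and intrinsics differ. Below is a minimal PyTorch sketch of the per-channel math, offered as an assumed numerical baseline; the helper name is hypothetical and not part of this change.
import torch
import torch.nn.functional as F

def causal_conv1d_silu_ref(x: torch.Tensor, weight: torch.Tensor, bias: float) -> torch.Tensor:
    # out[t] = silu(bias + sum_w weight[w] * x[t - (kWidth - 1 - w)]), matching the unrolled loop above
    k = weight.numel()
    x_pad = F.pad(x, (k - 1, 0))                 # causal left padding
    out = torch.empty_like(x)
    for t in range(x.numel()):
        out[t] = bias + (weight * x_pad[t:t + k]).sum()
    return out / (1 + torch.exp(-out))           # SiLU written as x / (1 + exp(-x)), as in the kernel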

View File

@ -0,0 +1,187 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedActivationQuant.h"
#include "tensorrt_llm/kernels/quantization.cuh"
#include "tensorrt_llm/kernels/quantization.h"
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
constexpr int kEltsPerThread = 8;
__device__ __forceinline__ float relu2_f32(float x)
{
float r = fmaxf(0.0f, x);
return r * r;
}
// Fused relu2 + NVFP4 quantization kernel.
//
// To match the unfused path (PyTorch relu2 -> cvt_warp_fp16_to_fp4), relu2 is
// computed in f32 then rounded back to native precision (bf16/fp16) before
// quantization. Absmax and scale-factor math follow cvt_warp_fp16_to_fp4 exactly.
// Column padding to a multiple of (4 * kSfVecSize) matches quantize_with_block_size
// for the swizzled SF layout.
template <typename T>
__global__ void fusedRelu2QuantizeKernel(T const* __restrict__ input, float const* __restrict__ sfScale,
uint32_t* __restrict__ outputFp4, uint32_t* __restrict__ outputSf, int m, int n)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr int kSfVecSize = 16;
constexpr int kNumThreadsPerSf = kSfVecSize / kEltsPerThread;
constexpr int kPackedPerThread = kEltsPerThread / 2;
using PackedType = std::conditional_t<std::is_same_v<T, half>, __half2, __nv_bfloat162>;
float const SFScaleVal = sfScale[0];
int const numColThreads = n / kEltsPerThread;
int const numColVecs = n / kSfVecSize;
int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
int const rowIdx = blockIdx.x;
if (rowIdx >= m)
return;
for (int colIdx = threadIdx.x; colIdx < numColThreadsPadded; colIdx += blockDim.x)
{
bool const isValidCol = colIdx < numColThreads;
PackedType packedVals[kPackedPerThread];
if (isValidCol)
{
int const inputOffset = rowIdx * n + colIdx * kEltsPerThread;
#pragma unroll
for (int i = 0; i < kPackedPerThread; i++)
{
float f0 = relu2_f32(static_cast<float>(input[inputOffset + i * 2]));
float f1 = relu2_f32(static_cast<float>(input[inputOffset + i * 2 + 1]));
if constexpr (std::is_same_v<T, half>)
{
packedVals[i] = __floats2half2_rn(f0, f1);
}
else
{
packedVals[i] = __floats2bfloat162_rn(f0, f1);
}
}
}
else
{
#pragma unroll
for (int i = 0; i < kPackedPerThread; i++)
{
if constexpr (std::is_same_v<T, half>)
{
packedVals[i] = __float2half2_rn(0.0f);
}
else
{
packedVals[i] = __float2bfloat162_rn(0.0f);
}
}
}
// Absmax in native precision, then reduce across the SF group (2 threads).
auto localMax = cuda_abs(packedVals[0]);
#pragma unroll
for (int i = 1; i < kPackedPerThread; i++)
{
localMax = cuda_max(localMax, cuda_abs(packedVals[i]));
}
localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
float vecMax = float(cuda_max(localMax.x, localMax.y));
// Scale-factor computation (identical to cvt_warp_fp16_to_fp4).
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
__nv_fp8_e4m3 fp8SF = __nv_fp8_e4m3(SFValue);
uint8_t fp8SFVal = fp8SF.__x;
SFValue = static_cast<float>(fp8SF);
float outputScale
= vecMax != 0.0f ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
if (colIdx % kNumThreadsPerSf == 0)
{
auto sfOutPtr = cvt_quant_get_sf_out_offset<uint32_t, kNumThreadsPerSf>(std::nullopt, rowIdx, colIdx,
std::optional<int>(m), numColVecs, outputSf, QuantizationSFLayout::SWIZZLED);
if (sfOutPtr != nullptr)
{
*sfOutPtr = fp8SFVal;
}
}
if (isValidCol)
{
float2 fp2Vals[kPackedPerThread];
#pragma unroll
for (int i = 0; i < kPackedPerThread; i++)
{
if constexpr (std::is_same_v<T, half>)
{
fp2Vals[i] = __half22float2(packedVals[i]);
}
else
{
fp2Vals[i] = __bfloat1622float2(packedVals[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
outputFp4[rowIdx * numColThreads + colIdx] = fp32_vec_to_e2m1(fp2Vals);
}
}
#else
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("FP4 quantization requires SM100 (Blackwell) or later!\n");
}
#endif
}
template <typename T>
void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf,
int m, int n, int sfVecSize, cudaStream_t stream)
{
constexpr int kSfVecSize = 16;
int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
int threadsPerBlock = min(512, numColThreadsPadded);
threadsPerBlock = max(32, ((threadsPerBlock + 31) / 32) * 32);
fusedRelu2QuantizeKernel<T><<<m, threadsPerBlock, 0, stream>>>(
input, sfScale, reinterpret_cast<uint32_t*>(outputFp4), reinterpret_cast<uint32_t*>(outputSf), m, n);
}
template void invokeFusedRelu2Quantize<half>(
half const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template void invokeFusedRelu2Quantize<__nv_bfloat16>(
__nv_bfloat16 const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t);
#endif
} // namespace kernels
TRTLLM_NAMESPACE_END
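The scale-factor arithmetic above follows cvt_warp_fp16_to_fp4: per 16-element block, take the absmax, scale it by sfScale / 6, round through e4m3, and derive the per-block output scale. A hedged PyTorch sketch of that block math follows; exact division stands in for the kernel's approximate-reciprocal intrinsics, and the swizzled SF layout and column padding are ignored.
import torch

def nvfp4_block_scale_ref(block: torch.Tensor, sf_scale: float):
    # block: 16 fp16/bf16 values after relu2; returns (decoded e4m3 SF, values scaled for the e2m1 cast)
    vec_max = float(block.float().abs().max())
    sf = torch.tensor(sf_scale * (vec_max / 6.0))           # 6.0 is the largest e2m1 magnitude
    sf_decoded = float(sf.to(torch.float8_e4m3fn).float())  # round-trip through e4m3, as the kernel does
    out_scale = 0.0 if vec_max == 0.0 else sf_scale / sf_decoded
    return sf_decoded, block.float() * out_scale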

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
#include <cstdint>
#include <cuda_runtime.h>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
template <typename T>
void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf,
int m, int n, int sfVecSize, cudaStream_t stream);
} // namespace kernels
TRTLLM_NAMESPACE_END

View File

@ -37,6 +37,7 @@ struct GeneralFP4AddBiasResidualPreLayerNormParam
T const* bias = nullptr;
T const* gamma = nullptr;
T const* beta = nullptr;
T* high_precision_normed_output = nullptr;
int m;
int n;

View File

@ -276,7 +276,7 @@ struct LowLatencyLayerNorm
}
typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type normed_output;
typename PackType<typename Traits::AccumulatorType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
high_precision_normed_output;
for (int j = 0; j < Traits::PACKED_ELEMS_PER_COMPUTE; j++)
{
@ -300,7 +300,7 @@ struct LowLatencyLayerNorm
}
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
high_precision_normed_output.array[j] = normed_out;
high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out;
}
if constexpr (Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR)
{

View File

@ -690,7 +690,7 @@ struct WarpSpecializedLayerNorm
typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
normed_output;
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type output;
typename PackType<typename Traits::AccumulatorType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
high_precision_normed_output;
#pragma unroll Traits::PACKED_ELEMS_PER_COMPUTE
@ -719,6 +719,11 @@ struct WarpSpecializedLayerNorm
normed_out += beta[j];
}
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out;
}
if constexpr (Traits::OUTPUT_SCALE != SCALE_TYPE::NONE)
{
static_assert(Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR);
@ -730,11 +735,6 @@ struct WarpSpecializedLayerNorm
output.array[j] = (typename Traits::InputType) data[m_offset][i][j];
}
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
high_precision_normed_output.array[j] = normed_out;
}
normed_output.array[j] = (typename Traits::OutputType) normed_out;
}

View File

@ -44,7 +44,7 @@ enum class SCALE_TYPE
};
template <typename T>
void invokeWSLayerNorm(WarpSpecializedParam<T> param, bool use_rms_norm, int ctas);
void invokeWSLayerNorm(WarpSpecializedParam<T> param, bool use_rms_norm, int ctas, bool output_hp_norm = false);
} // namespace kernels

View File

@ -31,7 +31,8 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
template <typename _Param, typename _InputType, typename _OutputType, typename _AccumulatorType, bool _RMS_NORM,
int _M_BLOCK, int _N_BLOCK, int _STAGES = 3, bool _PERSISTENT_MODE = true, bool _LOW_LATENCY_MODE = false>
int _M_BLOCK, int _N_BLOCK, int _STAGES = 3, bool _PERSISTENT_MODE = true, bool _LOW_LATENCY_MODE = false,
bool _HIGH_PRECISION_NORMED_OUTPUT = false>
struct FP4AddBiasResidualPreLayerNormTraits
{
@ -59,12 +60,12 @@ struct FP4AddBiasResidualPreLayerNormTraits
static constexpr bool PERSISTENT_MODE = _PERSISTENT_MODE;
static constexpr bool LOW_LATENCY_MODE = _LOW_LATENCY_MODE;
static constexpr bool PREFETCH_TO_L2 = false;
static constexpr bool HIGH_PRECISION_NORMED_OUTPUT = false;
static constexpr bool HIGH_PRECISION_NORMED_OUTPUT = _HIGH_PRECISION_NORMED_OUTPUT;
};
template <typename T>
void invokeWSLayerNormImpl(
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<T>> param, bool use_rms_norm, int ctas)
void invokeWSLayerNormImpl(WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<T>> param, bool use_rms_norm,
int ctas, bool output_hp_norm)
{
auto _invoke = [&](auto traits)
@ -80,10 +81,11 @@ void invokeWSLayerNormImpl(
{
int waves = ((param.m + Traits::M_BLOCK - 1) / Traits::M_BLOCK + ctas - 1) / ctas;
TLLM_LOG_DEBUG(
"Selected TILE_M = %d, N = %d, STAGE = %d, PERSISTENT_MODE = %d, LOW_LATENCY_MODE = %d for param M = "
"Selected TILE_M = %d, N = %d, STAGE = %d, PERSISTENT_MODE = %d, LOW_LATENCY_MODE = %d, "
"HIGH_PRECISION_NORMED_OUTPUT = %d for param M = "
"%d, N = %d, num_sms = %d. (waves = %d)\n",
Traits::M_BLOCK, Traits::N_BLOCK, Traits::STAGES, Traits::PERSISTENT_MODE, Traits::LOW_LATENCY_MODE,
param.m, param.n, ctas, waves);
Traits::HIGH_PRECISION_NORMED_OUTPUT, param.m, param.n, ctas, waves);
printed = true;
}
@ -117,15 +119,32 @@ void invokeWSLayerNormImpl(
constexpr auto PERSISTENT = decltype(persistent)::value;
constexpr auto LOW_LATENCY_MODE = decltype(low_latency_mode)::value;
// Select kernel variant based on use_rms_norm and output_hp_norm
if (use_rms_norm)
{
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE>{});
if (output_hp_norm)
{
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{});
}
else
{
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{});
}
}
else
{
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE>{});
if (output_hp_norm)
{
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{});
}
else
{
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{});
}
}
};
@ -308,16 +327,18 @@ void invokeWSLayerNormImpl(
template <>
void invokeWSLayerNorm<GeneralFP4AddBiasResidualPreLayerNormParam<half>>(
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<half>> param, bool use_rms_norm, int ctas)
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<half>> param, bool use_rms_norm, int ctas,
bool output_hp_norm)
{
invokeWSLayerNormImpl(param, use_rms_norm, ctas);
invokeWSLayerNormImpl(param, use_rms_norm, ctas, output_hp_norm);
}
template <>
void invokeWSLayerNorm<GeneralFP4AddBiasResidualPreLayerNormParam<__nv_bfloat16>>(
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<__nv_bfloat16>> param, bool use_rms_norm, int ctas)
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<__nv_bfloat16>> param, bool use_rms_norm, int ctas,
bool output_hp_norm)
{
invokeWSLayerNormImpl(param, use_rms_norm, ctas);
invokeWSLayerNormImpl(param, use_rms_norm, ctas, output_hp_norm);
}
} // namespace kernels

View File

@ -67,6 +67,7 @@ add_library(
dsv3FusedAGemmOp.cpp
fusedQKNormRopeOp.cpp
fusedAddRMSNormQuant.cpp
fusedActivationQuant.cpp
fusedTopkSoftmax.cpp
gatherTreeOp.cpp
groupRmsNormOp.cpp

View File

@ -0,0 +1,94 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/fusedActivationQuant.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
std::tuple<at::Tensor, at::Tensor> fused_relu2_quantize(
at::Tensor const& input, at::Tensor const& sf_scale, int64_t sf_vec_size)
{
CHECK_TH_CUDA(input);
CHECK_CONTIGUOUS(input);
CHECK_INPUT(sf_scale, torch::kFloat32);
auto const& inputShape = input.sizes();
TORCH_CHECK(inputShape.size() == 2, "input should be 2D tensor [M, N].");
int64_t const m = inputShape[0];
int64_t const n = inputShape[1];
TORCH_CHECK(sf_vec_size == 16, "sf_vec_size must be 16 for NVFP4.");
TORCH_CHECK(n % sf_vec_size == 0, "N must be divisible by sf_vec_size.");
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
at::Tensor output_fp4 = at::detail::empty_cuda({m, n / 2}, torch::kUInt8, input.device(), std::nullopt);
int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sf_vec_size);
at::Tensor output_sf = at::detail::empty_cuda({sfSize}, SF_DTYPE, input.device(), std::nullopt);
float const* sfScalePtr = sf_scale.data_ptr<float>();
if (input.scalar_type() == at::ScalarType::Half)
{
kernels::invokeFusedRelu2Quantize<half>(reinterpret_cast<half const*>(input.data_ptr()), sfScalePtr,
output_fp4.data_ptr<uint8_t>(), output_sf.data_ptr<uint8_t>(), static_cast<int>(m), static_cast<int>(n),
static_cast<int>(sf_vec_size), stream);
}
else if (input.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
kernels::invokeFusedRelu2Quantize<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()),
sfScalePtr, output_fp4.data_ptr<uint8_t>(), output_sf.data_ptr<uint8_t>(), static_cast<int>(m),
static_cast<int>(n), static_cast<int>(sf_vec_size), stream);
#else
C10_THROW_ERROR(NotImplementedError, "BFloat16 not enabled.");
#endif
}
else
{
C10_THROW_ERROR(NotImplementedError, "fused_relu2_quantize only supports fp16/bf16.");
}
return std::make_tuple(output_fp4, output_sf);
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("fused_relu2_quantize(Tensor input, Tensor sf_scale, int sf_vec_size=16) -> (Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fused_relu2_quantize", &tensorrt_llm::torch_ext::fused_relu2_quantize);
}
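Based on the schema and shape checks registered above, a minimal usage sketch (the tensor sizes are placeholders; the op requires a Blackwell-class GPU per the kernel's architecture guard):
import torch

x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")  # [M, N], N divisible by sf_vec_size
sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")     # global scale applied to the block SFs
fp4, sf = torch.ops.trtllm.fused_relu2_quantize(x, sf_scale, 16)
# fp4: [M, N/2] uint8, two e2m1 values per byte; sf: swizzled e4m3 scale factors as uint8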

View File

@ -43,16 +43,18 @@ namespace torch_ext
// gamma: [N] - RMSNorm weight (fp16/bf16)
// sf_scale: [1] - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// output_hp_norm: bool - if true, also output high-precision normalized values (same dtype as input) for the MoE gate.
// Returns:
// normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
// output: [M, N] - pre-norm output (input + residual), same dtype as input
// sf_out: scale factors for FP4 (uint8_t), swizzled layout
// high_precision_normed_output: [M, N] - normalized output before quant (only if output_hp_norm=true, else empty)
//
// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
double eps)
std::tuple<at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>> fused_add_rms_norm_quant(
at::Tensor const& input, at::Tensor const& residual, at::Tensor const& gamma,
std::optional<at::Tensor> const& sf_scale, bool use_rms_norm, double eps, bool output_hp_norm)
{
CHECK_TH_CUDA(input);
CHECK_CONTIGUOUS(input);
@ -116,6 +118,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tens
int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize);
at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt);
at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize);
std::optional<at::Tensor> high_precision_normed_output = std::nullopt;
if (output_hp_norm)
{
at::Tensor hp_normed_output_padded
= at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
high_precision_normed_output
= (m_padded == m) ? hp_normed_output_padded : hp_normed_output_padded.narrow(0, 0, m);
}
// Get number of SMs for persistent kernel
static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
@ -152,12 +162,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tens
param.bias = nullptr; \
param.gamma = reinterpret_cast<T const*>(gamma.data_ptr()); \
param.beta = nullptr; \
param.high_precision_normed_output \
= output_hp_norm ? reinterpret_cast<T*>(high_precision_normed_output.value().data_ptr()) : nullptr; \
param.m = static_cast<int>(m); \
param.n = static_cast<int>(n); \
param.layernorm_eps = static_cast<float>(eps); \
param.stream = stream; \
param.counters = counters; \
tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount); \
tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount, output_hp_norm); \
} while (0)
if (input.scalar_type() == at::ScalarType::Half)
@ -180,7 +192,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tens
#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT
// No explicit sync needed - kernel runs asynchronously on the stream
return std::make_tuple(normed_output, output, sf_out);
return std::make_tuple(normed_output, output, sf_out, high_precision_normed_output);
}
} // namespace torch_ext
@ -191,7 +203,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
"Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)");
"Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6, bool output_hp_norm=False) -> (Tensor, Tensor, "
"Tensor, Tensor?)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
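With the widened schema, the op now returns an optional fourth tensor. A hedged call sketch follows; the sizes are illustrative and chosen to satisfy the documented 2048 <= N <= 16384 constraint, and the Hopper/Blackwell requirement from the note above still applies.
import torch

m, n = 256, 4096
x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)
gamma = torch.ones(n, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")
normed_fp4, prenorm, sf_out, hp = torch.ops.trtllm.fused_add_rms_norm_quant(
    x, residual, gamma, sf_scale, True, eps=1e-6, output_hp_norm=True)
# normed_fp4: [m, n/8] int32 (packed FP4), prenorm: x + residual, sf_out: swizzled SFs,
# hp: [m, n] normalized output in the input dtype (None when output_hp_norm=False)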

View File

@ -1003,7 +1003,9 @@ def _register_fake():
sf_scale: Optional[torch.Tensor],
use_rms_norm: bool = True,
eps: float = 1e-5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
output_hp_norm: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
m, n = input.shape
# normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
@ -1013,4 +1015,22 @@ def _register_fake():
sf_vec_size = 16
sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
return normed_output_fp4, output, sf_out
# high_precision_normed_output: [M, N] optional, only when output_hp_norm=True
hp_output = input.new_empty(
(m, n), dtype=input.dtype) if output_hp_norm else None
return normed_output_fp4, output, sf_out, hp_output
@torch.library.register_fake("trtllm::fused_relu2_quantize")
def _(
input: torch.Tensor,
sf_scale: torch.Tensor,
sf_vec_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
# input: 2D tensor [M, N] (bf16 or fp16)
# output_fp4: [M, N/2] (packed FP4 values, 2 values per byte)
# output_sf: swizzled scale factors
output_shape, scale_shape = fp4_utils.get_fp4_shape(
input.shape, sf_vec_size, is_swizzled_layout=True)
output_fp4 = input.new_empty(output_shape, dtype=torch.uint8)
output_sf = input.new_empty((scale_shape, ), dtype=torch.uint8)
return output_fp4, output_sf

View File

@ -14,7 +14,6 @@
# limitations under the License.
import re
from typing import Dict, List, Optional
import torch
from torch import nn
@ -23,6 +22,7 @@ from transformers import AutoConfig, PretrainedConfig
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from tensorrt_llm._torch.utils import ActivationType, relu2
from tensorrt_llm.logger import logger
from ..attention_backend import AttentionMetadata
from ..distributed import AllReduce
@ -37,7 +37,7 @@ from ..modules.mlp import MLP
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .modeling_deepseekv3 import DeepseekV3MTPHead
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, register_auto_model
@ -121,7 +121,7 @@ class NemotronHMOE(nn.Module):
self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
):
super().__init__()
@ -242,13 +242,20 @@ class NemotronHMOE(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: torch.Tensor
| tuple[torch.Tensor | Fp4QuantizedTensor, torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
assert hidden_states.shape[-1] == self.hidden_dim
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
if isinstance(hidden_states, tuple):
hidden_states, hidden_states_hp = hidden_states
else:
hidden_states_hp = hidden_states
assert hidden_states_hp.shape[-1] == self.hidden_dim
orig_shape = hidden_states_hp.shape
hidden_states_hp_2d = hidden_states_hp.view(-1, self.hidden_dim)
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
def _compute_shared_output():
@ -259,7 +266,8 @@ class NemotronHMOE(nn.Module):
return shared_expert_output
def _compute_routed_output():
router_logits = self.gate(hidden_states)
# Gate uses high precision input for accurate routing decisions.
router_logits = self.gate(hidden_states_hp_2d)
routed_hidden_states = hidden_states
if self.use_latent_moe:
@ -301,7 +309,7 @@ class NemotronHLayer(DecoderLayer):
# - -> MLPLayer
# * -> TransformerLayer
layer_type: str,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
):
super().__init__()
@ -310,10 +318,31 @@ class NemotronHLayer(DecoderLayer):
self.layer_idx = layer_idx
self.layer_type = layer_type
self.is_nvfp4 = (model_config.quant_config is not None
and model_config.quant_config.quant_mode is not None
and model_config.quant_config.quant_mode.has_nvfp4())
# The fused RMSNorm+NVFP4 CUDA kernel only supports hidden_size values that
# match one of its tile sizes. Other hidden sizes (including non-power-of-2
# sizes that still fall inside a tile range) may hang the kernel, so the
# fused NVFP4 path is disabled for them.
# Supported tile sizes: 2048, 4096, 8192, 16384
_SUPPORTED_NVFP4_HIDDEN_SIZES = {2048, 4096, 8192, 16384}
if self.is_nvfp4 and config.hidden_size not in _SUPPORTED_NVFP4_HIDDEN_SIZES:
logger.warning_once(
f"Layer {layer_idx}: Disabling fused NVFP4 RMSNorm for hidden_size={config.hidden_size}. "
f"Supported sizes: {_SUPPORTED_NVFP4_HIDDEN_SIZES}. Using non-fused path.",
key=f"disable_nvfp4_rmsnorm_with_{config.hidden_size}")
self.is_nvfp4 = False
self.norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
# Enable fused NVFP4 quantization if possible.
# It may be overridden in the `_try_attach_nvfp4_scale` method.
quantize_type="nvfp4" if self.is_nvfp4 else None,
# Enable high-precision output for the MoE layer (only with NVFP4).
# It may be overridden in the `_try_attach_nvfp4_scale` method.
return_hp_output=layer_type == "E" and self.is_nvfp4,
)
if layer_type == "M":
@ -343,29 +372,71 @@ class NemotronHLayer(DecoderLayer):
else:
raise ValueError(f"{layer_type} is not supported")
def post_load_weights(self):
"""Post-process after loading weights."""
if self.norm.is_nvfp4 and not hasattr(self.norm, 'nvfp4_scale'):
self._try_attach_nvfp4_scale()
def _try_attach_nvfp4_scale(self):
"""Attach input_scale from mixer's first linear to norm for fused RMSNorm+Quant."""
# Normal handling for Mamba, MLP, and Attention layers.
first_linear_attr = {
'M': 'in_proj',
'-': 'up_proj',
'*': 'qkv_proj'
}.get(self.layer_type)
if first_linear_attr:
first_linear = getattr(self.mixer, first_linear_attr, None)
if first_linear and hasattr(first_linear, 'input_scale'):
self.norm.nvfp4_scale = first_linear.input_scale
return
# Special handling for the MoE layer: fetch shared_experts.up_proj.input_scale
# as a representative input scale.
if self.layer_type == 'E':
if (hasattr(self.mixer, 'shared_experts')
and self.mixer.shared_experts is not None
and hasattr(self.mixer.shared_experts, 'up_proj')
and hasattr(self.mixer.shared_experts.up_proj,
'input_scale') and
self.mixer.shared_experts.up_proj.input_scale is not None):
self.norm.nvfp4_scale = self.mixer.shared_experts.up_proj.input_scale
# Enable high precision output for MoE layer.
self.norm.return_hp_output = True
return
self.norm.is_nvfp4 = False
self.norm.return_hp_output = False
def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
residual: torch.Tensor | None = None,
spec_metadata: SpecMetadata | None = None,
**kwargs,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = torch.zeros_like(hidden_states)
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.norm.return_hp_output:
hidden_states, residual, high_precision_normed_output = self.norm(
hidden_states, residual)
hidden_states = (hidden_states, high_precision_normed_output)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states,
attn_metadata,
spec_metadata=spec_metadata,
**kwargs)
hidden_states = torch.add(hidden_states, residual)
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
hidden_states, None)
hidden_states, residual)
return hidden_states
return hidden_states, residual
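Taken together, the layer now threads the residual explicitly: the fused norm consumes (hidden_states, residual), returns the quantized activations plus the pre-norm sum, and for MoE layers also hands back a high-precision copy that travels to the mixer as a tuple so the gate can route on unquantized values. A condensed, hypothetical paraphrase of that flow (not code from this change):
def layer_flow(norm, mixer, hidden_states, residual):
    if residual is None:
        residual = hidden_states                    # first layer: the raw input becomes the residual
    if norm.return_hp_output:                       # MoE layers with fused NVFP4
        quant_out, residual, hp_out = norm(hidden_states, residual)
        mixer_in = (quant_out, hp_out)              # gate reads the high-precision copy
    else:
        mixer_in, residual = norm(hidden_states, residual)
    # the residual add now happens inside the next layer's fused norm (or the final norm_f)
    return mixer(mixer_in), residual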
class NemotronHModel(DecoderModel):
@ -426,10 +497,10 @@ class NemotronHModel(DecoderModel):
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
input_ids: torch.IntTensor | None = None,
position_ids: torch.IntTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
spec_metadata: SpecMetadata | None = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@ -443,16 +514,15 @@ class NemotronHModel(DecoderModel):
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
residual = torch.zeros_like(hidden_states)
for layer in self.layers[:self.num_hidden_layers]:
hidden_states = layer(position_ids,
hidden_states,
attn_metadata,
spec_metadata=spec_metadata,
mamba_metadata=mamba_metadata)
hidden_states = self.norm_f(hidden_states)
hidden_states, residual = layer(position_ids,
hidden_states,
residual=residual,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
mamba_metadata=mamba_metadata)
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
@ -517,7 +587,7 @@ class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel,
self.epilogue.extend(self.draft_model.mtp_layers)
self.epilogue.append(self.spec_worker)
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
new_weights = weight_mapper.preprocess_weights(weights)
super().load_weights(weights=new_weights, weight_mapper=weight_mapper)
@ -528,7 +598,7 @@ class NemotronHMTPDecoderLayer(NemotronHLayer):
self,
model_config: ModelConfig[NemotronHConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
has_start_projections: bool,
has_end_norm: bool,
layer_type: str,
@ -625,7 +695,7 @@ class NemotronHMTPDecoderLayer(NemotronHLayer):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None = None,
attn_metadata: Optional[AttentionMetadata] = None,
attn_metadata: AttentionMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.has_start_projections:
@ -672,7 +742,7 @@ class NemotronHMTP(nn.Module):
def __init__(self,
model_config: ModelConfig[NemotronHConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
is_separate_draft_engine: bool = False,
prefix: str = ""):
super().__init__()
@ -744,8 +814,8 @@ class NemotronHMTP(nn.Module):
hidden_states: torch.Tensor,
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
spec_metadata: Optional[SpecMetadata] = None,
all_rank_num_tokens: list[int] | None = None,
spec_metadata: SpecMetadata | None = None,
**kwargs,
) -> torch.Tensor:
inputs_embeds = embed_tokens(input_ids)

View File

@ -0,0 +1,176 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fused elementwise operations for Mamba2 prefill optimization."""
import torch
import triton
import triton.language as tl
@triton.jit
def _extract_transpose_prefill_kernel(
src_ptr,
dst_ptr,
num_prefill_tokens,
d_in_proj,
d_inner,
conv_dim,
BLOCK_SEQ: tl.constexpr,
BLOCK_CONV: tl.constexpr,
):
"""Extract src[0:num_prefill_tokens, d_inner:d_inner+conv_dim] and
transpose to dst[conv_dim, num_prefill_tokens]."""
pid_seq = tl.program_id(0)
pid_conv = tl.program_id(1)
seq_offsets = pid_seq * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
conv_offsets = pid_conv * BLOCK_CONV + tl.arange(0, BLOCK_CONV)
seq_mask = seq_offsets < num_prefill_tokens
conv_mask = conv_offsets < conv_dim
mask = seq_mask[:, None] & conv_mask[None, :]
src_offsets = seq_offsets[:, None] * d_in_proj + (d_inner + conv_offsets[None, :])
data = tl.load(src_ptr + src_offsets, mask=mask, other=0.0)
dst_offsets = conv_offsets[:, None] * num_prefill_tokens + seq_offsets[None, :]
tl.store(dst_ptr + dst_offsets, tl.trans(data), mask=conv_mask[:, None] & seq_mask[None, :])
def extract_transpose_xbc_prefill(
zxbcdt: torch.Tensor,
num_prefill_tokens: int,
d_inner: int,
conv_dim: int,
) -> torch.Tensor:
"""
Extract and transpose xbc slice from zxbcdt for causal_conv1d_fn.
Input: zxbcdt[num_tokens, d_in_proj]
Output: [conv_dim, num_prefill_tokens]
"""
out = torch.empty(conv_dim, num_prefill_tokens, dtype=zxbcdt.dtype, device=zxbcdt.device)
BLOCK_SEQ, BLOCK_CONV = 32, 128
grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(conv_dim, BLOCK_CONV))
_extract_transpose_prefill_kernel[grid](
zxbcdt,
out,
num_prefill_tokens,
zxbcdt.shape[1],
d_inner,
conv_dim,
BLOCK_SEQ,
BLOCK_CONV,
)
return out
@triton.jit
def _fused_conv_output_transpose_kernel(
src_ptr,
out_x_ptr,
out_B_ptr,
out_C_ptr,
num_prefill_tokens,
d_inner,
bc_size,
x_tiles,
bc_tiles,
BLOCK_SEQ: tl.constexpr,
BLOCK_DIM: tl.constexpr,
):
"""
Transpose and split conv1d output into x, B, C using linear grid mapping.
Grid: tiles [0, x_tiles) -> x, [x_tiles, x_tiles+bc_tiles) -> B, rest -> C
"""
tile_id = tl.program_id(0)
is_x = tile_id < x_tiles
is_B = (tile_id >= x_tiles) & (tile_id < x_tiles + bc_tiles)
local_tile = tl.where(
is_x, tile_id, tl.where(is_B, tile_id - x_tiles, tile_id - x_tiles - bc_tiles)
)
dim_size = tl.where(is_x, d_inner, bc_size)
num_dim_blocks = tl.cdiv(dim_size, BLOCK_DIM)
pid_seq = local_tile // num_dim_blocks
pid_dim = local_tile % num_dim_blocks
seq_offsets = pid_seq * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
dim_offsets = pid_dim * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
seq_mask = seq_offsets < num_prefill_tokens
dim_mask = dim_offsets < dim_size
src_offset = tl.where(is_x, 0, tl.where(is_B, d_inner, d_inner + bc_size))
src_indices = (src_offset + dim_offsets[:, None]) * num_prefill_tokens + seq_offsets[None, :]
data = tl.load(src_ptr + src_indices, mask=dim_mask[:, None] & seq_mask[None, :], other=0.0)
out_ptr = tl.where(is_x, out_x_ptr, tl.where(is_B, out_B_ptr, out_C_ptr))
dst_indices = seq_offsets[:, None] * dim_size + dim_offsets[None, :]
tl.store(out_ptr + dst_indices, tl.trans(data), mask=seq_mask[:, None] & dim_mask[None, :])
def fused_split_rearrange_after_conv1d(
xbc: torch.Tensor,
d_inner: int,
n_groups: int,
d_state: int,
nheads: int,
head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split and rearrange causal_conv1d output into contiguous x, B, C tensors.
Input: xbc[conv_dim, num_prefill_tokens]
Output: x[1, num_prefill_tokens, nheads, head_dim],
B[1, num_prefill_tokens, n_groups, d_state],
C[1, num_prefill_tokens, n_groups, d_state]
"""
conv_dim, num_prefill_tokens = xbc.shape
bc_size = n_groups * d_state
x_flat = torch.empty(num_prefill_tokens, d_inner, dtype=xbc.dtype, device=xbc.device)
B_flat = torch.empty(num_prefill_tokens, bc_size, dtype=xbc.dtype, device=xbc.device)
C_flat = torch.empty(num_prefill_tokens, bc_size, dtype=xbc.dtype, device=xbc.device)
BLOCK_SEQ, BLOCK_DIM = 64, 64
num_seq_blocks = triton.cdiv(num_prefill_tokens, BLOCK_SEQ)
x_tiles = num_seq_blocks * triton.cdiv(d_inner, BLOCK_DIM)
bc_tiles = num_seq_blocks * triton.cdiv(bc_size, BLOCK_DIM)
_fused_conv_output_transpose_kernel[(x_tiles + 2 * bc_tiles,)](
xbc,
x_flat,
B_flat,
C_flat,
num_prefill_tokens,
d_inner,
bc_size,
x_tiles,
bc_tiles,
BLOCK_SEQ,
BLOCK_DIM,
)
return (
x_flat.view(1, num_prefill_tokens, nheads, head_dim),
B_flat.view(1, num_prefill_tokens, n_groups, d_state),
C_flat.view(1, num_prefill_tokens, n_groups, d_state),
)
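Both Triton kernels in this file are pure data movement: a strided slice plus transpose before causal_conv1d_fn, and a split plus transpose after it. An assumed eager-PyTorch reference of the same transforms, useful for checking their outputs (the _ref names are hypothetical):
import torch

def extract_transpose_xbc_prefill_ref(zxbcdt, num_prefill_tokens, d_inner, conv_dim):
    # slice the xbc columns of the prefill tokens, then transpose to [conv_dim, num_prefill_tokens]
    return zxbcdt[:num_prefill_tokens, d_inner:d_inner + conv_dim].t().contiguous()

def fused_split_rearrange_after_conv1d_ref(xbc, d_inner, n_groups, d_state, nheads, head_dim):
    # xbc: [conv_dim, num_prefill_tokens] -> token-major, contiguous x, B, C
    tokens = xbc.shape[1]
    x, B, C = torch.split(xbc.t(), [d_inner, n_groups * d_state, n_groups * d_state], dim=-1)
    return (x.contiguous().view(1, tokens, nheads, head_dim),
            B.contiguous().view(1, tokens, n_groups, d_state),
            C.contiguous().view(1, tokens, n_groups, d_state))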

View File

@ -17,6 +17,8 @@ import math
from typing import Tuple
import torch
import triton
import triton.language as tl
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
@ -25,6 +27,86 @@ from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
use_cpp_mamba_cache_manager
@triton.jit
def _cu_seqlens_triton_kernel(
cu_seqlens_ptr, # [num_seqs + 1]
chunk_indices_ptr, # [N] output
chunk_offsets_ptr, # [N] output
num_seqs: tl.constexpr,
chunk_size: tl.constexpr,
N: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Computes chunk_indices and chunk_offsets in a single kernel launch."""
pid = tl.program_id(0)
chunk_start = pid * BLOCK_SIZE
offsets = chunk_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
chunk_indices = offsets.to(tl.int64)
chunk_offsets = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
p = 0
for seq_idx in range(num_seqs - 1):
seq_start = tl.load(cu_seqlens_ptr + seq_idx + 1).to(tl.int64)
seq_end = tl.load(cu_seqlens_ptr + seq_idx + 2).to(tl.int64)
is_misaligned = (seq_start % chunk_size) > 0
p = p + is_misaligned
s_chunk = seq_start // chunk_size + p
e_chunk = seq_end // chunk_size + p + ((seq_end % chunk_size) > 0)
in_range = (offsets >= s_chunk) & (offsets < e_chunk)
chunk_indices = tl.where(in_range & mask, chunk_indices - p,
chunk_indices)
is_start = (offsets == s_chunk)
chunk_offsets = tl.where(is_start & mask, seq_start % chunk_size,
chunk_offsets)
tl.store(chunk_indices_ptr + offsets, chunk_indices.to(tl.int32), mask=mask)
tl.store(chunk_offsets_ptr + offsets, chunk_offsets.to(tl.int32), mask=mask)
def cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens: torch.Tensor,
chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Optimized version of cu_seqlens_to_chunk_indices_offsets."""
device = cu_seqlens.device
num_seqs = cu_seqlens.numel() - 1
if num_seqs == 0:
return (torch.empty(0, dtype=torch.int, device=device),
torch.empty(0, dtype=torch.int, device=device))
cu = cu_seqlens.to(dtype=torch.int64)
total_seqlens = cu[-1].item()
if num_seqs == 1:
# Fast path for single sequence (no boundaries to process)
N = (total_seqlens + chunk_size - 1) // chunk_size
return (torch.arange(N, device=device, dtype=torch.int),
torch.zeros(N, device=device, dtype=torch.int))
seq_starts = cu[1:-1]
misaligned = ((seq_starts % chunk_size) > 0).to(torch.int64)
p = torch.cumsum(misaligned, dim=0)
extra_chunks = p[-1].item() if p.numel() > 0 else 0
N = (total_seqlens + chunk_size - 1) // chunk_size + extra_chunks
chunk_indices = torch.empty(N, device=device, dtype=torch.int)
chunk_offsets = torch.empty(N, device=device, dtype=torch.int)
BLOCK_SIZE = 256
grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
_cu_seqlens_triton_kernel[grid](
cu,
chunk_indices,
chunk_offsets,
num_seqs=num_seqs,
chunk_size=chunk_size,
N=N,
BLOCK_SIZE=BLOCK_SIZE,
)
return chunk_indices, chunk_offsets
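A hedged usage sketch of the wrapper above, with a hand-checked size: the output length is ceil(total_tokens / chunk_size) plus one extra logical chunk for every sequence that does not start on a chunk boundary, and chunk_offsets records where such a sequence begins inside its first chunk.
import torch

# Three sequences of lengths 5, 8, 4 and chunk_size = 4 (example values, not from this change).
cu_seqlens = torch.tensor([0, 5, 13, 17], dtype=torch.int32, device="cuda")
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton(cu_seqlens, chunk_size=4)
# ceil(17 / 4) = 5 physical chunks, plus 2 misaligned sequence starts (at 5 and 13) -> N = 7.
assert chunk_indices.numel() == chunk_offsets.numel() == 7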
def cu_seqlens_to_chunk_indices_offsets(
cu_seqlens: torch.Tensor,
chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
@ -117,15 +199,15 @@ class Mamba2Metadata:
self.state_indices = torch.zeros(max_batch_size,
dtype=torch.int32,
device="cuda")
self._query_start_loc_long_buf = torch.arange(0,
max_batch_size + 1,
dtype=torch.long,
device="cuda")
self._query_start_loc_buf = torch.zeros(max_batch_size + 1,
dtype=torch.int,
device="cuda")
self.query_start_loc_long = self._query_start_loc_long_buf
self.query_start_loc = self._query_start_loc_buf
# Pre-allocated buffers.
self._arange_buffer = torch.arange(max_batch_size + 1,
dtype=torch.int,
device="cuda")
self._arange_buffer_long = self._arange_buffer.to(torch.long)
self._cu_seqlens_long = torch.zeros(max_batch_size + 1,
dtype=torch.long,
device="cuda")
def prepare(self, attn_metadata: AttentionMetadata):
batch_size = attn_metadata.seq_lens.shape[0]
@ -158,47 +240,32 @@ class Mamba2Metadata:
dtype=torch.int,
out=self.cu_seqlens[1:num_contexts + 1])
torch.add(self.cu_seqlens[num_contexts],
torch.arange(1,
batch_size - num_contexts + 1,
dtype=self.cu_seqlens.dtype,
device=self.cu_seqlens.device),
self._arange_buffer[1:batch_size - num_contexts + 1],
out=self.cu_seqlens[num_contexts + 1:batch_size + 1])
# Need both `query_start_loc` and `query_start_loc_long` because `causal_conv1d_fn`
# accepts only `int32` while `chunk_gated_delta_rule` accepts only `long`.
self._query_start_loc_buf[:batch_size +
1] = self.cu_seqlens[:batch_size + 1]
self.query_start_loc = self._query_start_loc_buf[:batch_size + 1]
self._query_start_loc_long_buf[:batch_size + 1].copy_(
self.query_start_loc.to(torch.long), non_blocking=True)
self.query_start_loc_long = self._query_start_loc_long_buf[:
batch_size
+ 1]
self.query_start_loc = self.cu_seqlens[:batch_size + 1]
self._cu_seqlens_long[:batch_size + 1].copy_(self.query_start_loc)
self.query_start_loc_long = self._cu_seqlens_long[:batch_size + 1]
self.seq_idx = torch.repeat_interleave(
torch.arange(num_contexts,
dtype=torch.int,
device=self.cu_seqlens.device),
self._arange_buffer[:num_contexts],
repeats=context_lens,
output_size=num_ctx_tokens).unsqueeze(0)
num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq
self.has_initial_states[:num_contexts] = torch.tensor(
num_cached_tokens_per_seq[:num_contexts]) > 0
# precomputed bool to avoid host<->device syncs during forward pass
self.use_initial_states = torch.any(
self.has_initial_states[:num_contexts]).item()
initial_states = [
num_cached_tokens_per_seq[i] > 0 for i in range(num_contexts)
]
self.use_initial_states = any(initial_states)
if self.use_initial_states:
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
self.has_initial_states[:num_contexts] = torch.tensor(
initial_states, dtype=torch.bool)
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton(
self.cu_seqlens[:num_contexts + 1], self.chunk_size)
else:
self.chunk_indices = None
self.chunk_offsets = None
else:
self.query_start_loc = None
torch.arange(0,
batch_size + 1,
dtype=torch.long,
device=self.cu_seqlens.device,
out=self._query_start_loc_long_buf[:batch_size + 1])
self.query_start_loc_long = self._query_start_loc_long_buf[:
batch_size
+ 1]
self.query_start_loc_long = self._arange_buffer_long[:batch_size +
1]

View File

@ -33,6 +33,8 @@ from ..linear import Linear, TensorParallelMode
from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .causal_conv1d_triton import \
causal_conv1d_update as causal_conv1d_update_triton
from .fuse_elementwise_ops import (extract_transpose_xbc_prefill,
fused_split_rearrange_after_conv1d)
from .layernorm_gated import RMSNorm as RMSNormGated
from .selective_state_update import \
selective_state_update as selective_state_update_native
@ -227,15 +229,17 @@ class Mamba2Mixer(nn.Module):
# in_proj
zxbcdt = self.in_proj(hidden_states)
z, xbc, dt = torch.split(
zxbcdt,
[self.tp_d_inner, self.tp_conv_dim, self.tp_nheads],
dim=-1,
)
# Split z and dt with views.
z = zxbcdt[:, :self.tp_d_inner]
dt = zxbcdt[:, self.tp_d_inner + self.tp_conv_dim:]
z_p, z_d = torch.split(z, seqlen_split_size, dim=0)
xbc_p, xbc_d = torch.split(xbc, seqlen_split_size, dim=0)
dt_p, dt_d = torch.split(dt, seqlen_split_size, dim=0)
# Decode path uses regular view since no transpose is needed.
xbc_d = zxbcdt[num_prefill_tokens:num_actual_tokens,
self.tp_d_inner:self.tp_d_inner + self.tp_conv_dim]
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
@ -243,8 +247,8 @@ class Mamba2Mixer(nn.Module):
zxbcdt.shape[0],
(self.num_heads * self.head_dim) // self.tp_size,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
dtype=zxbcdt.dtype,
device=zxbcdt.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
@ -259,27 +263,29 @@ class Mamba2Mixer(nn.Module):
has_initial_states = mamba_metadata.has_initial_states[:
num_prefills]
xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1),
# Fused kernel to avoid expensive .contiguous() call in causal_conv1d_fn.
xbc_p_t = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens,
self.tp_d_inner,
self.tp_conv_dim)
xbc_p = causal_conv1d_fn(xbc_p_t,
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_states=conv_states,
has_initial_state=has_initial_states,
query_start_loc=cu_seqlens,
cache_indices=state_indices_p).transpose(
0, 1)
cache_indices=state_indices_p)
x_p, B_p, C_p = torch.split(xbc_p.unsqueeze(0), [
# Fused kernel to avoid expensive .contiguous() calls after split/rearrange.
x_p, B_p, C_p = fused_split_rearrange_after_conv1d(
xbc_p,
self.tp_d_inner,
self.tp_ngroups * self.d_state,
self.tp_ngroups * self.d_state,
],
dim=-1)
x_p = rearrange(x_p, "b l (h p) -> b l h p", h=self.tp_nheads)
self.tp_ngroups,
self.d_state,
self.tp_nheads,
self.head_dim,
)
dt_p = dt_p.unsqueeze(0)
B_p = rearrange(B_p, "b l (g n) -> b l g n", g=self.tp_ngroups)
C_p = rearrange(C_p, "b l (g n) -> b l g n", g=self.tp_ngroups)
z_p = rearrange(z_p.unsqueeze(0),
"b l (h p) -> b l h p",
h=self.tp_nheads)

View File

@ -26,6 +26,124 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
@triton.autotune(
configs=[
# =================================================================
# Higher warp count configs for better latency hiding
# More warps = more instructions in flight = better memory latency hiding
# =================================================================
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=2,
num_warps=8, # 8 warps = 256 threads per block
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=2,
num_warps=8, # 8 warps for better latency hiding
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=2,
num_warps=8,
),
# Smaller tiles with more stages for software pipelining
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64
},
num_stages=2,
num_warps=4,
),
# =================================================================
# Low register pressure configs (num_stages=1) for large dstate
# =================================================================
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=1,
num_warps=4,
),
# num_stages=2 configs - moderate register pressure
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64
},
num_stages=2,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=2,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=2,
num_warps=4,
),
# Original configs for smaller dstate values
triton.Config(
{
"BLOCK_SIZE_M": 128,
@ -355,14 +473,17 @@ def _chunk_scan_fwd_kernel(
if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m),
# Use exp2 for faster computation: exp(x) = exp2(x * log2(e))
scale_m = tl.where(seq_idx_m == seq_idx_prev,
tl.math.exp2(dA_cs_m * 1.4426950408889634),
0.0)
else:
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
scale_m = tl.math.exp2(
(dA_cs_m - dA_cs_m_boundary) * 1.4426950408889634)
else:
scale_m = tl.exp(dA_cs_m)
scale_m = tl.math.exp2(dA_cs_m * 1.4426950408889634)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(
C_ptrs,
@ -421,7 +542,9 @@ def _chunk_scan_fwd_kernel(
other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
# Use exp2 for faster computation: exp(x) = exp2(x * log2(e))
cb *= tl.math.exp2(
(dA_cs_m[:, None] - dA_cs_k[None, :]) * 1.4426950408889634)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
cb *= dt_k
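The exp-to-exp2 rewrites above lean on the exact identity exp(x) = exp2(x * log2(e)); per the added comments, exp2 is the cheaper primitive on the GPU. A quick check of the identity and of the hard-coded constant:
import math

assert math.isclose(2.0 ** (1.25 * math.log2(math.e)), math.exp(1.25), rel_tol=1e-12)
assert math.isclose(1.4426950408889634, math.log2(math.e), rel_tol=1e-15)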

View File

@ -128,6 +128,54 @@ def _chunk_cumsum_fwd_kernel(
@triton.autotune(
configs=[
# Small headdim/dstate configs (hdim<=64, dstate<=128) - increased parallelism
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4,
),
# Low register pressure configs for large dstate (dstate=128)
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64
},
num_stages=2,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64
},
num_stages=2,
num_warps=4,
),
# Original configs for larger headdim/dstate values
triton.Config(
{
"BLOCK_SIZE_M": 128,
@ -175,40 +223,13 @@ def _chunk_cumsum_fwd_kernel(
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=2,
),
],
key=["hdim", "dstate", "chunk_size"],
)
@ -351,6 +372,41 @@ def _chunk_state_fwd_kernel(
@triton.autotune(
configs=[
# Small headdim/dstate configs for better parallelism
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=3,
num_warps=4),
# Low register pressure configs
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64
},
num_stages=2,
num_warps=4),
# Original configs for larger dimensions
triton.Config(
{
"BLOCK_SIZE_M": 128,
@ -393,36 +449,12 @@ def _chunk_state_fwd_kernel(
num_warps=4),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=2),
],
key=["hdim", "dstate", "chunk_size"],
)

View File

@ -8,23 +8,25 @@ from tensorrt_llm.mapping import Mapping
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor, relu2
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
class MLP(nn.Module):
def __init__(self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
layer_idx: Optional[int] = None,
reduce_output: bool = True,
overridden_tp_size: Optional[int] = None):
def __init__(
self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
layer_idx: Optional[int] = None,
reduce_output: bool = True,
overridden_tp_size: Optional[int] = None,
):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
@ -81,7 +83,22 @@ class MLP(nn.Module):
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
reduce_output=reduce_output)
reduce_output=reduce_output,
)
self._use_fused_relu2_quant = False
def create_weights(self):
self.up_proj.create_weights()
self.down_proj.create_weights()
has_nvfp4 = hasattr(self.down_proj,
'has_nvfp4') and self.down_proj.has_nvfp4
has_kernel = hasattr(torch.ops.trtllm, 'fused_relu2_quantize')
has_scale = hasattr(self.down_proj, 'input_scale')
is_relu2 = self.activation is relu2
self._use_fused_relu2_quant = has_nvfp4 and has_kernel and has_scale and is_relu2
def forward(
self,
@ -92,11 +109,34 @@ class MLP(nn.Module):
return self.forward_lora(x, lora_params=lora_params)
x_up = self.up_proj(x)
x_act = self.activation(x_up)
if self._use_fused_relu2_quant:
x_act = self._fused_relu2_quant(x_up)
else:
x_act = self.activation(x_up)
x_down = self.down_proj(x_act)
return x_down
def _fused_relu2_quant(self, x: torch.Tensor) -> Fp4QuantizedTensor:
x_flat = x.view(-1, x.shape[-1])
if not x_flat.is_contiguous():
x_flat = x_flat.contiguous()
if x_flat.dtype not in (torch.float16, torch.bfloat16):
x_flat = x_flat.to(torch.bfloat16)
fp4_tensor, sf_tensor = torch.ops.trtllm.fused_relu2_quantize(
x_flat, self.down_proj.input_scale, 16)
return Fp4QuantizedTensor(
fp4_tensor=fp4_tensor,
scaling_factor=sf_tensor,
is_sf_swizzled=True,
)
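For comparison, a minimal sketch of the unfused sequence this path replaces: relu2 followed by a separate NVFP4 quantization. The fp4_quantize arguments (sf_vec_size=16, use_ue8m0=False, is_sf_swizzled_layout=True) mirror the unit test added in this commit and are assumptions about the call site rather than part of this diff.

import torch
import torch.nn.functional as F

def unfused_relu2_then_quant(x_up: torch.Tensor, input_scale: torch.Tensor):
    # relu2 activation: square(relu(x)), matching tensorrt_llm's relu2
    x_act = torch.square(F.relu(x_up))
    # Separate NVFP4 quantization with a swizzled scale-factor layout
    # (sf_vec_size=16, use_ue8m0=False, is_sf_swizzled_layout=True).
    fp4, sf = torch.ops.trtllm.fp4_quantize(x_act, input_scale, 16, False, True)
    return fp4, sf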
def forward_lora(
self,
x: torch.Tensor,



@ -41,6 +41,7 @@ class RMSNorm(nn.Module):
use_gemma: bool = False,
quantize_type: Optional[str] = None,
use_cuda_tile: bool = False,
return_hp_output: bool = False,
):
super().__init__()
@ -72,6 +73,7 @@ class RMSNorm(nn.Module):
self.variance_epsilon = eps
self.use_gemma = use_gemma
self.use_cuda_tile = use_cuda_tile
self.return_hp_output = return_hp_output
def forward(
self,
@ -80,7 +82,8 @@ class RMSNorm(nn.Module):
Optional[torch.Tensor],
_ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]:
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]], Tuple[
Fp4QuantizedTensor, torch.Tensor, torch.Tensor]]:
has_residual = residual is not self._ARGUMENT_NOT_SPECIFIED_SENTINEL
if not has_residual:
residual = None
@ -116,14 +119,16 @@ class RMSNorm(nn.Module):
sf_scale = nvfp4_scale.contiguous()
normed_fp4_i32, residual_out_2d, sf_fused = torch.ops.trtllm.fused_add_rms_norm_quant(
results = torch.ops.trtllm.fused_add_rms_norm_quant(
hs_2d,
res_2d,
gamma,
sf_scale,
True,
eps=self.variance_epsilon,
output_hp_norm=self.return_hp_output,
)
normed_fp4_i32, residual_out_2d, sf_fused = results[:3]
normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
if len(orig_shape) != 2:
normed_fp4_u8 = normed_fp4_u8.reshape(*orig_shape[:-1], n // 2)
@ -132,9 +137,21 @@ class RMSNorm(nn.Module):
residual_out = residual_out_2d
hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused)
return (hidden_states_fused,
residual_out) if has_residual else hidden_states_fused
elif self.use_cuda_tile:
outputs = [hidden_states_fused]
if has_residual:
outputs.append(residual_out)
if self.return_hp_output:
high_precision_normed_output = results[3].reshape(orig_shape)
outputs.append(high_precision_normed_output)
return outputs[0] if len(outputs) == 1 else tuple(outputs)
if self.return_hp_output:
raise ValueError(
"Auxiliary high precision output is only supported for NVFP4 fused path"
)
if self.use_cuda_tile:
if isinstance(residual, torch.Tensor):
# Use fused residual kernel
hidden_states = hidden_states.contiguous()
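The new flag simply threads through to the fused kernel's optional fourth output. A minimal call sketch (hypothetical shapes; requires SM100+ and the trtllm extension; mirrors the unit tests added below):

import torch

hs = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
res = torch.randn_like(hs)
gamma = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# With output_hp_norm=True the op returns a fourth, high-precision normalized
# tensor alongside the FP4 data, residual, and scale factors; with the default
# False that slot comes back as None and RMSNorm returns only the fused outputs.
normed_fp4, residual_out, sf, hp_norm = torch.ops.trtllm.fused_add_rms_norm_quant(
    hs, res, gamma, sf_scale, True, eps=1e-6, output_hp_norm=True
)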


@ -593,7 +593,9 @@ class MambaCacheManager(BaseResourceManager):
return self._impl.get_intermediate_conv_states(layer_idx)
def is_speculative(self) -> bool:
assert not self._use_cpp, "is_speculative is not supported in CppMambaCacheManager"
if self._use_cpp:
# CppMambaCacheManager does not support speculative decoding for now.
return False
return self._impl.is_speculative()
def mamba_layer_cache(


@ -1,5 +1,6 @@
import contextlib
import functools
import inspect
import itertools
import os
import unittest.mock
@ -443,9 +444,9 @@ class Runner:
def forward(position_ids, hidden_states, attn_metadata, residual, **kwargs):
# TODO: to be more general, we should call DecoderModel.forward
residual_fusion = hasattr(model.model.layers[layer_indices[0]], "next_layer_layernorm")
for layer_idx in layer_indices:
layer = model.model.layers[layer_idx]
residual_fusion = "residual" in inspect.signature(layer.forward).parameters
if residual_fusion:
hidden_states, residual = layer(
position_ids, hidden_states, attn_metadata, residual, **kwargs


@ -0,0 +1,246 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn.functional as F
from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
def mamba_conv1d_ref(x, past_conv_state, conv_weight, conv_bias, apply_silu):
"""
Reference implementation for causal conv1d.
Arguments:
x: [batch_size, dim, seq_len]
past_conv_state: [batch_size, dim, dconv-1]
conv_weight: [dim, 1, dconv]
conv_bias: [dim]
apply_silu: whether to apply SiLU to the convolution output
Output:
y: [batch_size, dim, seq_len]
present_conv_state: [batch_size, dim, dconv-1]
"""
assert x.dim() == 3
assert past_conv_state.dim() == 3
assert conv_weight.dim() == 3
assert conv_bias.dim() == 1
batch_size, dim, seq_len = x.shape
assert past_conv_state.shape[0] == batch_size
assert past_conv_state.shape[1] == dim
dconv = past_conv_state.shape[2] + 1
assert conv_weight.shape[0] == dim
assert conv_weight.shape[1] == 1
assert conv_weight.shape[2] == dconv
padded_x = torch.cat([past_conv_state, x], dim=2)
present_conv_state = padded_x[:, :, -(dconv - 1) :]
x_conv = F.conv1d(padded_x, conv_weight, bias=conv_bias, groups=dim)
y = F.silu(x_conv) if apply_silu else x_conv
return y, present_conv_state
def trtllm_causal_conv1d_available():
"""Check if trtllm.causal_conv1d_fwd is available."""
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "causal_conv1d_fwd")
skip_unsupported = pytest.mark.skipif(
not torch.cuda.is_available() or not trtllm_causal_conv1d_available(),
reason="Requires CUDA and trtllm.causal_conv1d_fwd op",
)
@skip_unsupported
class TestCausalConv1d:
"""Tests for the causal_conv1d CUDA kernel."""
@pytest.mark.parametrize("dtype", ["float16", "bfloat16", "float32"])
@pytest.mark.parametrize("apply_silu", [True, False])
@pytest.mark.parametrize("dim", [256, 512, 1024, 2048])
def test_basic_correctness(self, dtype, apply_silu, dim):
"""Test basic correctness against reference implementation."""
torch.manual_seed(42)
device = "cuda"
torch_dtype = getattr(torch, dtype)
batch_size = 4
seq_len = 32
dconv = 4
std_dev = 0.5
x = torch.randn(batch_size, dim, seq_len, dtype=torch_dtype, device=device)
x = x * std_dev
conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=torch_dtype, device=device)
conv_weight = torch.randn(dim, 1, dconv, dtype=torch_dtype, device=device)
conv_bias = torch.randn(dim, dtype=torch_dtype, device=device)
x_kernel = x.clone()
conv_state_kernel = conv_state.clone()
conv_weight_input = conv_weight.squeeze(1).contiguous()
torch.ops.trtllm.causal_conv1d_fwd(
x_kernel,
conv_weight_input,
conv_bias,
conv_state_kernel,
None, # query_start_loc
None, # cache_indices
None, # has_initial_state
apply_silu,
PAD_SLOT_ID,
)
out_ref, conv_state_ref = mamba_conv1d_ref(
x, conv_state, conv_weight, conv_bias, apply_silu
)
torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_various_batch_sizes(self, batch_size):
"""Test with various batch sizes."""
torch.manual_seed(42)
device = "cuda"
dtype = torch.bfloat16
dim = 1024
seq_len = 64
dconv = 4
apply_silu = True
x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
conv_bias = torch.randn(dim, dtype=dtype, device=device)
x_kernel = x.clone()
conv_state_kernel = conv_state.clone()
conv_weight_input = conv_weight.squeeze(1).contiguous()
torch.ops.trtllm.causal_conv1d_fwd(
x_kernel,
conv_weight_input,
conv_bias,
conv_state_kernel,
None,
None,
None,
apply_silu,
PAD_SLOT_ID,
)
out_ref, conv_state_ref = mamba_conv1d_ref(
x, conv_state, conv_weight, conv_bias, apply_silu
)
torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1)
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
@pytest.mark.parametrize("dconv", [2, 3, 4])
def test_various_kernel_widths(self, dconv):
"""Test with different convolution kernel widths."""
torch.manual_seed(42)
device = "cuda"
dtype = torch.bfloat16
batch_size = 4
dim = 1024
seq_len = 64
apply_silu = True
x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
conv_bias = torch.randn(dim, dtype=dtype, device=device)
x_kernel = x.clone()
conv_state_kernel = conv_state.clone()
conv_weight_input = conv_weight.squeeze(1).contiguous()
torch.ops.trtllm.causal_conv1d_fwd(
x_kernel,
conv_weight_input,
conv_bias,
conv_state_kernel,
None,
None,
None,
apply_silu,
PAD_SLOT_ID,
)
out_ref, conv_state_ref = mamba_conv1d_ref(
x, conv_state, conv_weight, conv_bias, apply_silu
)
torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1)
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
def test_with_initial_state(self):
"""Test with non-zero initial conv state."""
torch.manual_seed(42)
device = "cuda"
dtype = torch.bfloat16
batch_size = 4
dim = 1024
seq_len = 32
dconv = 4
apply_silu = True
x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
# Non-zero initial state
conv_state = torch.randn(batch_size, dim, dconv - 1, dtype=dtype, device=device)
conv_state = conv_state * 0.5
conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
conv_bias = torch.randn(dim, dtype=dtype, device=device)
conv_state_kernel = conv_state.clone()
# Need to tell the kernel about initial state
has_initial_state = torch.ones(batch_size, dtype=torch.bool, device=device)
query_start_loc = torch.tensor(
[0] + [seq_len * (i + 1) for i in range(batch_size)],
dtype=torch.int32,
device=device,
)
# Reshape for varlen format
x_varlen = x.transpose(1, 2).reshape(-1, dim).T.contiguous()
conv_weight_input = conv_weight.squeeze(1).contiguous()
torch.ops.trtllm.causal_conv1d_fwd(
x_varlen,
conv_weight_input,
conv_bias,
conv_state_kernel,
query_start_loc,
None, # cache_indices
has_initial_state,
apply_silu,
PAD_SLOT_ID,
)
out_ref_list = []
conv_state_ref_list = []
for b in range(batch_size):
out_b, state_b = mamba_conv1d_ref(
x[b : b + 1],
conv_state[b : b + 1],
conv_weight,
conv_bias,
apply_silu,
)
out_ref_list.append(out_b)
conv_state_ref_list.append(state_b)
out_ref = torch.cat(out_ref_list, dim=0)
conv_state_ref = torch.cat(conv_state_ref_list, dim=0)
x_kernel_reshaped = (
x_varlen.T.reshape(batch_size, seq_len, dim).transpose(1, 2).contiguous()
)
torch.testing.assert_close(x_kernel_reshaped, out_ref, rtol=1e-2, atol=1e-1)
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
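The transpose/reshape pair above packs batched [B, D, L] activations into the [D, B*L] varlen buffer the kernel expects, with query_start_loc marking where each sequence begins. A small CPU-only round-trip sketch (illustrative shapes only, not part of the test):

import torch

B, D, L = 2, 4, 3  # illustrative shapes only
x = torch.arange(B * D * L, dtype=torch.float32).reshape(B, D, L)
x_varlen = x.transpose(1, 2).reshape(-1, D).T.contiguous()  # [D, B*L]
assert x_varlen.shape == (D, B * L)
# Undo the packing the same way the test does before comparing to the reference.
x_back = x_varlen.T.reshape(B, L, D).transpose(1, 2).contiguous()
assert torch.equal(x_back, x)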


@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for fused elementwise operations in Mamba2 prefill."""
import pytest
import torch
from tensorrt_llm._torch.modules.mamba.fuse_elementwise_ops import (
extract_transpose_xbc_prefill,
fused_split_rearrange_after_conv1d,
)
skip_no_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA required for triton kernels",
)
def extract_transpose_xbc_prefill_ref(
zxbcdt: torch.Tensor,
num_prefill_tokens: int,
d_inner: int,
conv_dim: int,
) -> torch.Tensor:
"""Reference implementation for extract_transpose_xbc_prefill."""
# Extract the xbc slice and transpose
xbc = zxbcdt[:num_prefill_tokens, d_inner : d_inner + conv_dim]
return xbc.transpose(0, 1).contiguous()
def fused_split_rearrange_after_conv1d_ref(
xbc: torch.Tensor,
d_inner: int,
n_groups: int,
d_state: int,
nheads: int,
head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Reference implementation for fused_split_rearrange_after_conv1d."""
conv_dim, num_prefill_tokens = xbc.shape
bc_size = n_groups * d_state
# Transpose and split
xbc_t = xbc.transpose(0, 1).contiguous() # [num_prefill_tokens, conv_dim]
x, B, C = torch.split(xbc_t, [d_inner, bc_size, bc_size], dim=-1)
x = x.contiguous().view(1, num_prefill_tokens, nheads, head_dim)
B = B.contiguous().view(1, num_prefill_tokens, n_groups, d_state)
C = C.contiguous().view(1, num_prefill_tokens, n_groups, d_state)
return x, B, C
@skip_no_cuda
@pytest.mark.parametrize("num_prefill_tokens", [1, 32, 128, 1024])
@pytest.mark.parametrize(
"d_inner,conv_dim,d_in_proj", [(256, 512, 800), (512, 1024, 1600), (1024, 2048, 3200)]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_extract_transpose_xbc_prefill(num_prefill_tokens, d_inner, conv_dim, d_in_proj, dtype):
"""Test extract_transpose_xbc_prefill matches reference implementation."""
torch.manual_seed(42)
device = torch.device("cuda")
num_total_tokens = num_prefill_tokens + 16
zxbcdt = torch.randn(num_total_tokens, d_in_proj, dtype=dtype, device=device)
out_ref = extract_transpose_xbc_prefill_ref(zxbcdt, num_prefill_tokens, d_inner, conv_dim)
out_fused = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens, d_inner, conv_dim)
assert out_fused.shape == out_ref.shape, f"Shape mismatch: {out_fused.shape} vs {out_ref.shape}"
torch.testing.assert_close(out_fused, out_ref, rtol=1e-3, atol=1e-3)
@skip_no_cuda
@pytest.mark.parametrize("num_prefill_tokens", [1, 32, 128, 1024])
@pytest.mark.parametrize(
"nheads,head_dim,n_groups,d_state", [(8, 64, 1, 128), (16, 64, 2, 64), (32, 64, 4, 64)]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_split_rearrange_after_conv1d(
num_prefill_tokens, nheads, head_dim, n_groups, d_state, dtype
):
"""Test fused_split_rearrange_after_conv1d matches reference implementation."""
torch.manual_seed(42)
device = torch.device("cuda")
d_inner = nheads * head_dim
bc_size = n_groups * d_state
conv_dim = d_inner + 2 * bc_size
xbc = torch.randn(conv_dim, num_prefill_tokens, dtype=dtype, device=device)
x_ref, B_ref, C_ref = fused_split_rearrange_after_conv1d_ref(
xbc, d_inner, n_groups, d_state, nheads, head_dim
)
x_fused, B_fused, C_fused = fused_split_rearrange_after_conv1d(
xbc, d_inner, n_groups, d_state, nheads, head_dim
)
assert x_fused.shape == x_ref.shape, f"x shape mismatch: {x_fused.shape} vs {x_ref.shape}"
assert B_fused.shape == B_ref.shape, f"B shape mismatch: {B_fused.shape} vs {B_ref.shape}"
assert C_fused.shape == C_ref.shape, f"C shape mismatch: {C_fused.shape} vs {C_ref.shape}"
torch.testing.assert_close(x_fused, x_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(B_fused, B_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(C_fused, C_ref, rtol=1e-3, atol=1e-3)


@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for Mamba2 metadata preparation optimizations."""
import pytest
import torch
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import (
cu_seqlens_to_chunk_indices_offsets,
cu_seqlens_to_chunk_indices_offsets_triton,
)
skip_no_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA required for triton kernels",
)
@skip_no_cuda
class TestCuSeqlensToChunkIndicesOffsets:
"""Tests for cu_seqlens_to_chunk_indices_offsets_triton function."""
def test_empty_sequence(self):
"""Test with empty cu_seqlens (no sequences)."""
cu_seqlens = torch.tensor([0], dtype=torch.int, device="cuda")
chunk_size = 8
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
assert indices_triton.numel() == 0
assert offsets_triton.numel() == 0
def test_single_sequence_aligned(self):
"""Test with a single sequence that aligns with chunk size."""
cu_seqlens = torch.tensor([0, 16], dtype=torch.int, device="cuda")
chunk_size = 8
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
torch.testing.assert_close(indices_triton, indices_ref)
torch.testing.assert_close(offsets_triton, offsets_ref)
def test_single_sequence_unaligned(self):
"""Test with a single sequence that doesn't align with chunk size."""
cu_seqlens = torch.tensor([0, 10], dtype=torch.int, device="cuda")
chunk_size = 8
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
torch.testing.assert_close(indices_triton, indices_ref)
torch.testing.assert_close(offsets_triton, offsets_ref)
def test_two_sequences_aligned(self):
"""Test with two sequences, both aligned with chunk boundaries."""
cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int, device="cuda")
chunk_size = 8
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
torch.testing.assert_close(indices_triton, indices_ref)
torch.testing.assert_close(offsets_triton, offsets_ref)
def test_two_sequences_misaligned(self):
"""Test with two sequences where second starts at misaligned position."""
# Example from docstring: cu_seqlens = [0, 5, 10], chunk_size = 8
# -> chunk_indices = [0, 0, 1], chunk_offsets = [0, 5, 0]
cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.int, device="cuda")
chunk_size = 8
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
# Verify against expected values from docstring
expected_indices = torch.tensor([0, 0, 1], dtype=torch.int, device="cuda")
expected_offsets = torch.tensor([0, 5, 0], dtype=torch.int, device="cuda")
torch.testing.assert_close(indices_ref, expected_indices)
torch.testing.assert_close(offsets_ref, expected_offsets)
torch.testing.assert_close(indices_triton, indices_ref)
torch.testing.assert_close(offsets_triton, offsets_ref)
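For readability, a worked reading of that docstring example (an interpretation sketch, not library code):

# With cu_seqlens = [0, 5, 10] and chunk_size = 8, the 10 tokens span two
# physical chunks (tokens 0-7 and 8-9), but the sequence boundary at token 5
# splits the first chunk, giving three logical chunks:
#   logical 0 -> physical chunk 0, offset 0 (tokens 0-4, all of sequence 0)
#   logical 1 -> physical chunk 0, offset 5 (tokens 5-7, start of sequence 1)
#   logical 2 -> physical chunk 1, offset 0 (tokens 8-9, rest of sequence 1)
# hence chunk_indices = [0, 0, 1] and chunk_offsets = [0, 5, 0].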
@pytest.mark.parametrize("chunk_size", [8, 16, 32, 64, 128])
def test_multiple_sequences_various_chunk_sizes(self, chunk_size):
"""Test with multiple sequences and various chunk sizes."""
# Create sequences with varying lengths
cu_seqlens = torch.tensor([0, 10, 25, 40, 60, 75], dtype=torch.int, device="cuda")
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
torch.testing.assert_close(indices_triton, indices_ref)
torch.testing.assert_close(offsets_triton, offsets_ref)
def test_all_sequences_within_one_chunk(self):
"""Test when all sequences fit within a single chunk."""
cu_seqlens = torch.tensor([0, 2, 4, 6], dtype=torch.int, device="cuda")
chunk_size = 64 # Large chunk size
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
cu_seqlens, chunk_size
)
torch.testing.assert_close(indices_triton, indices_ref)
torch.testing.assert_close(offsets_triton, offsets_ref)


@ -0,0 +1,223 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for fused relu2 + NVFP4 quantization kernel."""
import pytest
import torch
import torch.nn.functional as F
from tests.unittest.utils.util import getSMVersion
def fused_relu2_quantize_available():
"""Check if the fused_relu2_quantize op is available."""
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fused_relu2_quantize")
def fp4_quantize_available():
"""Check if the fp4_quantize op is available."""
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fp4_quantize")
skip_unless_fused_relu2_quantize = pytest.mark.skipif(
getSMVersion() < 100 or not fused_relu2_quantize_available(),
reason="Requires SM100+ (Blackwell) and trtllm.fused_relu2_quantize op",
)
skip_unless_fused_relu2_and_fp4_quantize = pytest.mark.skipif(
getSMVersion() < 100 or not fused_relu2_quantize_available() or not fp4_quantize_available(),
reason="Requires SM100+ (Blackwell) and trtllm fused_relu2_quantize + fp4_quantize ops",
)
# FP4 E2M1 lookup table for reference implementation
E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
def relu2(x: torch.Tensor) -> torch.Tensor:
"""Reference relu2 activation: square(relu(x))."""
return torch.square(F.relu(x))
def cast_to_fp4(weight: torch.Tensor) -> torch.Tensor:
"""Cast tensor values to FP4 E2M1 format (as uint8)."""
device = weight.device
mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
mask_shape = list(weight.shape)
mask = mask.expand([*mask_shape, 7])
sign_bit = (weight < 0).to(torch.uint8)
weight_abs = weight.abs()
ord_val = torch.searchsorted(E2M1_BOUNDS.to(device), weight_abs, out_int32=True).to(torch.uint8)
round_val = torch.any((weight_abs.unsqueeze(-1) == E2M1_BOUNDS.to(device)) * mask, dim=-1)
fp4_val = (sign_bit * 0b1000 + ord_val + round_val).to(torch.uint8)
return fp4_val
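A small worked example of the E2M1 mapping implemented above, reusing cast_to_fp4 from this file; the expected codes are hand-derived from E2M1_BOUNDS and added here purely for illustration.

def _fp4_code_examples():
    # E2M1 magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}; E2M1_BOUNDS holds the
    # midpoints between neighbours, and the 0/1 mask rounds exact ties toward
    # the even code.
    x = torch.tensor([0.8, -2.4, 0.25], dtype=torch.float32)
    codes = cast_to_fp4(x)
    # 0.8  -> code 2  (nearest value 1.0)
    # -2.4 -> code 12 (sign bit | 4, i.e. -2.0)
    # 0.25 -> exact tie between 0 and 0.5, rounds to the even code 0
    assert codes.tolist() == [2, 12, 0]
    return codes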
def quantize_nvfp4_ref(
input: torch.Tensor, sf_scale: torch.Tensor, sf_vec_size: int = 16
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reference NVFP4 quantization implementation.
Args:
input: Input tensor [M, N], already activated (e.g., after relu2)
sf_scale: Per-tensor scaling factor (sf_scale = amax / (6 * 448))
sf_vec_size: Block size for per-block scaling (default 16)
Returns:
Tuple of (fp4_packed, scale_factors)
"""
m, n = input.shape
assert n % sf_vec_size == 0, f"N ({n}) must be divisible by sf_vec_size ({sf_vec_size})"
# Reshape for block-wise quantization
input_blocked = input.view(m, n // sf_vec_size, sf_vec_size)
# Compute per-block amax
per_block_amax = input_blocked.abs().amax(dim=-1).float()
# Compute per-block scale: amax / 6.0
per_block_scale = per_block_amax / 6.0
# Quantize per-block scale to FP8
q_per_block_scale = per_block_scale / sf_scale
q_per_block_scale[per_block_scale == 0] = 1.0
q_per_block_scale_fp8 = q_per_block_scale.to(torch.float8_e4m3fn)
# Dequantize scale for actual quantization
scale_dequant = q_per_block_scale_fp8.float() * sf_scale
# Scale the input
scale_expanded = scale_dequant.unsqueeze(-1).expand_as(input_blocked)
scaled_input = input_blocked / (scale_expanded + 1e-12)
scaled_input = scaled_input.view(m, n)
# Cast to FP4
fp4_vals = cast_to_fp4(scaled_input)
# Pack two FP4 values into one uint8
packed = (fp4_vals[..., 1::2] << 4) | fp4_vals[..., 0::2]
return packed, q_per_block_scale_fp8
def fused_relu2_quantize_ref(
input: torch.Tensor, sf_scale: torch.Tensor, sf_vec_size: int = 16
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reference implementation for fused relu2 + NVFP4 quantization.
Args:
input: Input tensor [M, N]
sf_scale: Per-tensor scaling factor
sf_vec_size: Block size for per-block scaling (default 16)
Returns:
Tuple of (fp4_packed, scale_factors)
"""
# Apply relu2 activation
activated = relu2(input)
# Quantize to NVFP4
return quantize_nvfp4_ref(activated, sf_scale, sf_vec_size)
@skip_unless_fused_relu2_quantize
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_relu2_quantize_zeros(dtype):
"""Test fused_relu2_quantize with inputs that produce zeros after relu2."""
device = torch.device("cuda")
# All negative inputs -> relu2 produces all zeros
m, n = 32, 64
input_tensor = -torch.abs(torch.randn(m, n, dtype=dtype, device=device))
sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(input_tensor, sf_scale, 16)
assert fp4_fused.shape == (m, n // 2)
assert (fp4_fused == 0).all(), "All negative inputs should produce zero output"
@skip_unless_fused_relu2_and_fp4_quantize
@pytest.mark.parametrize("m", [1, 16, 64, 128])
@pytest.mark.parametrize("n", [32, 64, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_relu2_quantize_vs_separate_ops(m, n, dtype):
"""
Compare fused_relu2_quantize kernel output against separate relu2 + fp4_quantize.
This test verifies that the fused CUDA kernel produces FP4 packed values that
closely match running relu2 activation followed by fp4_quantize separately.
Note: Due to floating point precision differences in intermediate calculations
(e.g., FMA vs separate mul+add), a small percentage of values at quantization
boundaries may differ. We require >= 99% match rate.
"""
torch.manual_seed(42)
device = torch.device("cuda")
input_tensor = torch.randn(m, n, dtype=dtype, device=device)
activated = relu2(input_tensor)
sf_scale = (activated.abs().amax().float() / (6.0 * 448.0)).to(device)
sf_scale = sf_scale.view(1)
fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize(
activated,
sf_scale,
16,
False,
True, # use_ue8m0=False, is_sf_swizzled_layout=True
)
fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(
input_tensor.contiguous(), sf_scale, 16
)
match_rate = (fp4_fused == fp4_separate).float().mean().item()
assert match_rate >= 0.99, (
f"FP4 values match rate {match_rate:.4f} < 0.99 for shape ({m}, {n}), dtype {dtype}"
)
@skip_unless_fused_relu2_and_fp4_quantize
def test_fused_relu2_quantize_vs_separate_ops_various_sf_scales():
"""
Test with various sf_scale values to ensure consistent behavior.
"""
device = torch.device("cuda")
m, n = 64, 128
dtype = torch.bfloat16
torch.manual_seed(123)
input_tensor = torch.randn(m, n, dtype=dtype, device=device)
activated = relu2(input_tensor)
# Test with different sf_scale values
for scale_multiplier in [0.1, 1.0, 10.0]:
sf_scale = (
(activated.abs().amax().float() / (6.0 * 448.0) * scale_multiplier).to(device).view(1)
)
fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize(
activated, sf_scale, 16, False, True
)
fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(
input_tensor.contiguous(), sf_scale, 16
)
match_rate = (fp4_fused == fp4_separate).float().mean().item()
assert match_rate >= 0.99, (
f"FP4 values match rate {match_rate:.4f} < 0.99 with scale_multiplier={scale_multiplier}"
)


@ -0,0 +1,336 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for fused_add_rms_norm_quant with/without output_hp_norm support."""
import pytest
import torch
from tests.unittest.utils.util import getSMVersion
def fused_add_rms_norm_quant_available():
"""Check if the fused_add_rms_norm_quant op is available."""
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fused_add_rms_norm_quant")
skip_unsupported = pytest.mark.skipif(
getSMVersion() < 100 or not fused_add_rms_norm_quant_available(),
reason="Requires Blackwell+ (SM100+) and trtllm.fused_add_rms_norm_quant op",
)
def rms_norm_ref(
hidden_states: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reference RMSNorm implementation with residual addition.
Args:
hidden_states: Input tensor [M, N]
residual: Residual tensor [M, N]
gamma: Weight tensor [N]
eps: Epsilon for numerical stability
Returns:
Tuple of (normalized_output, residual_output)
"""
input_dtype = hidden_states.dtype
# Add residual
hidden_states_fp32 = hidden_states.float() + residual.float()
residual_out = hidden_states_fp32.to(input_dtype)
# RMSNorm
variance = hidden_states_fp32.pow(2).mean(-1, keepdim=True)
normed = hidden_states_fp32 * torch.rsqrt(variance + eps)
normed_output = (gamma.float() * normed).to(input_dtype)
return normed_output, residual_out
def get_swizzled_sf_indices(m: int, n: int, sf_vec_size: int = 16) -> list[int]:
"""
Compute the valid indices in swizzled SF layout for given m and n.
The swizzled layout uses 128x4 tiles:
- SF layout: [numMTiles, numKTiles, 32 (outerM), 4 (innerM), 4 (innerK)]
- Each SF block has 128 rows, padded to multiple of 128
- Each SF block has columns padded to multiple of 4
Args:
m: Number of rows
n: Hidden dimension
sf_vec_size: Number of elements sharing one scale factor (default 16)
Returns:
List of valid indices in the swizzled buffer
"""
num_col_vecs = n // sf_vec_size # Number of SF columns
indices = []
for m_idx in range(m):
for k_idx in range(num_col_vecs):
# Compute swizzled offset using 128x4 tile layout
inner_k_idx = k_idx % 4
inner_k_stride = 1
inner_m_idx = (m_idx % 128) // 32
inner_m_stride = 4 * inner_k_stride # 4
outer_m_idx = m_idx % 32
outer_m_stride = 4 * inner_m_stride # 16
k_tile_idx = k_idx // 4
k_tile_stride = 32 * outer_m_stride # 512
num_k_tiles = (num_col_vecs + 3) // 4
m_tile_idx = m_idx // 128
m_tile_stride = num_k_tiles * k_tile_stride
offset = (
m_tile_idx * m_tile_stride
+ k_tile_idx * k_tile_stride
+ outer_m_idx * outer_m_stride
+ inner_m_idx * inner_m_stride
+ inner_k_idx * inner_k_stride
)
indices.append(offset)
return indices
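A quick hand check of the strides above (derived from the code itself, added for illustration):

# Example offsets for the 128x4 tile layout (strides: inner_k=1, inner_m=4,
# outer_m=16, k_tile=512):
#   (m_idx=0,  k_idx=5) -> k_tile 1, inner_k 1  -> offset 1*512 + 1 = 513
#   (m_idx=33, k_idx=0) -> outer_m 1, inner_m 1 -> offset 1*16 + 1*4 = 20
# so consecutive SF columns of a row are contiguous only within a group of 4,
# and rows advance by the outer_m stride of 16 within each 32-row group.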
@skip_unsupported
@pytest.mark.parametrize("m", [1, 16, 64, 128])
@pytest.mark.parametrize("n", [2048, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_add_rms_norm_quant_basic(m, n, dtype):
"""Test basic functionality of fused_add_rms_norm_quant without hp_output."""
torch.manual_seed(42)
device = torch.device("cuda")
# Create input tensors
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
residual = torch.randn(m, n, dtype=dtype, device=device)
gamma = torch.ones(n, dtype=dtype, device=device)
# Compute sf_scale (per-tensor scale)
eps = 1e-6
normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps)
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
# Run fused kernel without hp_output
normed_fp4, residual_out, sf_out, dummy_output = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps
)
assert dummy_output is None
# Verify output shapes
assert normed_fp4.shape[0] == m
assert residual_out.shape == (m, n)
assert residual_out.dtype == dtype
@skip_unsupported
@pytest.mark.parametrize("m", [1, 16, 64, 128])
@pytest.mark.parametrize("n", [2048, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_add_rms_norm_quant_with_hp_output(m, n, dtype):
"""Test fused_add_rms_norm_quant with output_hp_norm=True."""
torch.manual_seed(42)
device = torch.device("cuda")
# Create input tensors
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
residual = torch.randn(m, n, dtype=dtype, device=device)
gamma = torch.ones(n, dtype=dtype, device=device)
# Compute sf_scale
eps = 1e-6
normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps)
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
# Run fused kernel with hp_output
results = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
)
# Should return 4 tensors when output_hp_norm=True
assert len(results) == 4, f"Expected 4 outputs, got {len(results)}"
normed_fp4, residual_out, sf_out, hp_normed_output = results
# Verify output shapes
assert normed_fp4.shape[0] == m
assert residual_out.shape == (m, n)
assert hp_normed_output.shape == (m, n)
# Verify dtypes
assert residual_out.dtype == dtype
assert hp_normed_output.dtype == dtype
# Verify high precision output matches reference
torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2)
# Verify residual output matches reference
torch.testing.assert_close(residual_out, residual_ref, rtol=1e-3, atol=1e-3)
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_add_rms_norm_quant_hp_output_consistency(dtype):
"""Test that hp_output is consistent with the quantized output."""
torch.manual_seed(42)
device = torch.device("cuda")
m, n = 64, 4096
# Create input tensors
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
residual = torch.randn(m, n, dtype=dtype, device=device)
gamma = torch.ones(n, dtype=dtype, device=device)
eps = 1e-6
normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps)
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
# Run without hp_output
results_no_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=False
)
assert results_no_hp[3] is None
normed_fp4_no_hp, residual_out_no_hp, sf_out_no_hp = results_no_hp[:3]
# Run with hp_output
results_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
)
normed_fp4_hp, residual_out_hp, sf_out_hp, hp_normed_output = results_hp
# The quantized outputs should be identical regardless of hp_output flag
torch.testing.assert_close(normed_fp4_hp, normed_fp4_no_hp, rtol=0, atol=0)
torch.testing.assert_close(residual_out_hp, residual_out_no_hp, rtol=0, atol=0)
# Compare only valid SF indices (swizzled layout pads rows to 128)
valid_sf_indices = get_swizzled_sf_indices(m, n)
sf_out_hp_valid = sf_out_hp[valid_sf_indices]
sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices]
torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0)
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_add_rms_norm_quant_gamma_weight(dtype):
"""Test fused_add_rms_norm_quant with non-trivial gamma weights."""
torch.manual_seed(42)
device = torch.device("cuda")
m, n = 32, 2048
# Create input tensors
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
residual = torch.randn(m, n, dtype=dtype, device=device)
# Non-trivial gamma weights
gamma = torch.randn(n, dtype=dtype, device=device) * 0.5 + 1.0
eps = 1e-6
normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps)
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
# Run with hp_output
results = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
)
normed_fp4, residual_out, sf_out, hp_normed_output = results
# Verify high precision output matches reference
torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2)
# Verify residual output matches reference
torch.testing.assert_close(residual_out, residual_ref, rtol=1e-3, atol=1e-3)
@skip_unsupported
def test_fused_add_rms_norm_quant_large_batch():
"""Test fused_add_rms_norm_quant with larger batch size."""
torch.manual_seed(42)
device = torch.device("cuda")
m, n = 512, 4096
dtype = torch.bfloat16
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
residual = torch.randn(m, n, dtype=dtype, device=device)
gamma = torch.ones(n, dtype=dtype, device=device)
eps = 1e-6
normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps)
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
results = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
)
normed_fp4, residual_out, sf_out, hp_normed_output = results
assert hp_normed_output.shape == (m, n)
torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2)
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_low_latency_layernorm_hp_output_consistency(dtype):
"""
Test that low_latency_layernorm hp_output is consistent with/without the flag.
The quantized outputs should be identical regardless of output_hp_norm flag.
Uses m=1 to trigger the low_latency_layernorm path.
"""
torch.manual_seed(42)
device = torch.device("cuda")
m, n = 1, 4096 # m=1 triggers low_latency_layernorm path
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
residual = torch.randn(m, n, dtype=dtype, device=device)
gamma = torch.ones(n, dtype=dtype, device=device)
eps = 1e-6
normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps)
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
# Run without hp_output
results_no_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps
)
assert len(results_no_hp) == 4, f"Expected 4 outputs, got {len(results_no_hp)}"
assert results_no_hp[3] is None, "Expected 4th output to be None when output_hp_norm=False"
normed_fp4_no_hp, residual_out_no_hp, sf_out_no_hp = results_no_hp[:3]
# Run with hp_output
results_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
)
assert len(results_hp) == 4, f"Expected 4 outputs, got {len(results_hp)}"
normed_fp4_hp, residual_out_hp, sf_out_hp, hp_normed_output = results_hp
# The quantized outputs should be identical regardless of hp_output flag
torch.testing.assert_close(normed_fp4_hp, normed_fp4_no_hp, rtol=0, atol=0)
torch.testing.assert_close(residual_out_hp, residual_out_no_hp, rtol=0, atol=0)
# Compare only valid SF indices (swizzled layout pads rows to 128)
valid_sf_indices = get_swizzled_sf_indices(m, n)
sf_out_hp_valid = sf_out_hp[valid_sf_indices]
sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices]
torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0)