From 421eb9e39c850b89d0c5f266752695b40c175c39 Mon Sep 17 00:00:00 2001 From: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:25:31 +0800 Subject: [PATCH] [None][feat] Optimize NemotronH model with elementwise and nvfp4 fusion (#11273) Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> --- .../kernels/causalConv1d/causalConv1d.cu | 52 ++- .../kernels/fusedActivationQuant.cu | 187 ++++++++++ .../kernels/fusedActivationQuant.h | 33 ++ .../fusedLayernormKernels/layernorm_param.h | 1 + .../low_latency_layernorm.cuh | 4 +- .../fusedLayernormKernels/ws_layernorm.cuh | 12 +- .../fusedLayernormKernels/ws_layernorm.h | 2 +- .../ws_layernorm_fp4_traits.cu | 49 ++- cpp/tensorrt_llm/thop/CMakeLists.txt | 1 + .../thop/fusedActivationQuant.cpp | 94 +++++ .../thop/fusedAddRMSNormQuant.cpp | 25 +- .../_torch/custom_ops/cpp_custom_ops.py | 24 +- .../_torch/models/modeling_nemotron_h.py | 142 ++++++-- .../modules/mamba/fuse_elementwise_ops.py | 176 +++++++++ .../_torch/modules/mamba/mamba2_metadata.py | 143 ++++++-- .../_torch/modules/mamba/mamba2_mixer.py | 46 +-- .../_torch/modules/mamba/ssd_chunk_scan.py | 131 ++++++- .../_torch/modules/mamba/ssd_chunk_state.py | 142 +++++--- tensorrt_llm/_torch/modules/mlp.py | 68 +++- tensorrt_llm/_torch/modules/rms_norm.py | 27 +- .../_torch/pyexecutor/mamba_cache_manager.py | 4 +- .../tools/layer_wise_benchmarks/runner.py | 3 +- .../modules/mamba/test_causal_conv1d.py | 246 +++++++++++++ .../mamba/test_fuse_elementwise_ops.py | 113 ++++++ .../modules/mamba/test_mamba2_metadata.py | 133 +++++++ .../modules/test_fused_activation_quant.py | 223 ++++++++++++ .../modules/test_fused_add_rms_norm_quant.py | 336 ++++++++++++++++++ 27 files changed, 2203 insertions(+), 214 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/fusedActivationQuant.cu create mode 100644 cpp/tensorrt_llm/kernels/fusedActivationQuant.h create mode 100644 cpp/tensorrt_llm/thop/fusedActivationQuant.cpp create mode 100644 tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py create mode 100644 tests/unittest/_torch/modules/mamba/test_causal_conv1d.py create mode 100644 tests/unittest/_torch/modules/mamba/test_fuse_elementwise_ops.py create mode 100644 tests/unittest/_torch/modules/mamba/test_mamba2_metadata.py create mode 100644 tests/unittest/_torch/modules/test_fused_activation_quant.py create mode 100644 tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py diff --git a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu index 8ec6bbbf82..a5f22858ac 100644 --- a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu +++ b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu @@ -43,6 +43,8 @@ struct Causal_conv1d_fwd_kernel_traits static_assert(kWidth <= kNElts); static constexpr bool kIsVecLoad = kIsVecLoad_; using vec_t = typename BytesToType::Type; + static_assert(kNThreads_ % 32 == 0, "kNThreads must be a multiple of 32 for warp shuffle"); + static_assert(sizeof(vec_t) == 16, "vec_t must be 16 bytes for warp shuffle optimization"); using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; using BlockStoreT = cub::BlockStore; @@ -123,7 +125,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C #pragma unroll for (int i = 0; i < kWidth; ++i) { - weight_vals[i] = float(weight[i * params.weight_width_stride]); + weight_vals[i] = float(__ldg(&weight[i * params.weight_width_stride])); } constexpr int kChunkSize = 
kNThreads * kNElts; @@ -144,20 +146,41 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); } x += kChunkSize; + + int const lane_id = tidx & 31; + vec_t high_val = reinterpret_cast(x_vals_load)[1]; + __syncthreads(); // Thread kNThreads - 1 don't write yet, so that thread 0 can read // the last elements of the previous chunk. if (tidx < kNThreads - 1) { - smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; + smem_exchange[tidx] = high_val; } __syncthreads(); - reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + + // Get neighbor data: use warp shuffle for most threads, shared memory for warp boundaries + vec_t neighbor; + uint32_t* high_val_p = reinterpret_cast(&high_val); + uint32_t* nbr_p = reinterpret_cast(&neighbor); + nbr_p[0] = __shfl_up_sync(0xFFFFFFFF, high_val_p[0], 1); + nbr_p[1] = __shfl_up_sync(0xFFFFFFFF, high_val_p[1], 1); + nbr_p[2] = __shfl_up_sync(0xFFFFFFFF, high_val_p[2], 1); + nbr_p[3] = __shfl_up_sync(0xFFFFFFFF, high_val_p[3], 1); + + // Lane 0 must use shared memory to handle the cross-warp boundary. + // thread 0 uses the last element of the previous chunk. + if (lane_id == 0) + { + neighbor = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + } + reinterpret_cast(x_vals_load)[0] = neighbor; + __syncthreads(); // Now thread kNThreads - 1 can write the last elements of the current chunk. if (tidx == kNThreads - 1) { - smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; + smem_exchange[tidx] = high_val; } float x_vals[2 * kNElts]; @@ -169,22 +192,33 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C float out_vals[kNElts]; #pragma unroll - for (int i = 0; i < kNElts; ++i) + // Process 2 outputs at a time for better ILP (instruction level parallelism). + for (int i = 0; i < kNElts; i += 2) { - out_vals[i] = bias_val; + float acc0 = bias_val; + float acc1 = bias_val; #pragma unroll for (int w = 0; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + float wt = weight_vals[w]; + acc0 = __fmaf_rn(wt, x_vals[kNElts + i - (kWidth - w - 1)], acc0); + acc1 = __fmaf_rn(wt, x_vals[kNElts + i + 1 - (kWidth - w - 1)], acc1); } + out_vals[i] = acc0; + out_vals[i + 1] = acc1; } if (params.silu_activation) { #pragma unroll - for (int i = 0; i < kNElts; ++i) + for (int i = 0; i < kNElts; i += 2) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + // SiLU: x * sigmoid(x) = x / (1 + exp(-x)) + // Using fast math: __expf and __frcp_rn + float v0 = out_vals[i]; + float v1 = out_vals[i + 1]; + out_vals[i] = v0 * __frcp_rn(1.0f + __expf(-v0)); + out_vals[i + 1] = v1 * __frcp_rn(1.0f + __expf(-v1)); } } diff --git a/cpp/tensorrt_llm/kernels/fusedActivationQuant.cu b/cpp/tensorrt_llm/kernels/fusedActivationQuant.cu new file mode 100644 index 0000000000..67481cb56f --- /dev/null +++ b/cpp/tensorrt_llm/kernels/fusedActivationQuant.cu @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/fusedActivationQuant.h" +#include "tensorrt_llm/kernels/quantization.cuh" +#include "tensorrt_llm/kernels/quantization.h" + +#include +#include +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +constexpr int kEltsPerThread = 8; + +__device__ __forceinline__ float relu2_f32(float x) +{ + float r = fmaxf(0.0f, x); + return r * r; +} + +// Fused relu2 + NVFP4 quantization kernel. +// +// To match the unfused path (PyTorch relu2 -> cvt_warp_fp16_to_fp4), relu2 is +// computed in f32 then rounded back to native precision (bf16/fp16) before +// quantization. Absmax and scale-factor math follow cvt_warp_fp16_to_fp4 exactly. +// Column padding to a multiple of (4 * kSfVecSize) matches quantize_with_block_size +// for the swizzled SF layout. +template +__global__ void fusedRelu2QuantizeKernel(T const* __restrict__ input, float const* __restrict__ sfScale, + uint32_t* __restrict__ outputFp4, uint32_t* __restrict__ outputSf, int m, int n) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr int kSfVecSize = 16; + constexpr int kNumThreadsPerSf = kSfVecSize / kEltsPerThread; + constexpr int kPackedPerThread = kEltsPerThread / 2; + + using PackedType = std::conditional_t, __half2, __nv_bfloat162>; + + float const SFScaleVal = sfScale[0]; + int const numColThreads = n / kEltsPerThread; + int const numColVecs = n / kSfVecSize; + int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread; + int const rowIdx = blockIdx.x; + + if (rowIdx >= m) + return; + + for (int colIdx = threadIdx.x; colIdx < numColThreadsPadded; colIdx += blockDim.x) + { + bool const isValidCol = colIdx < numColThreads; + PackedType packedVals[kPackedPerThread]; + + if (isValidCol) + { + int const inputOffset = rowIdx * n + colIdx * kEltsPerThread; +#pragma unroll + for (int i = 0; i < kPackedPerThread; i++) + { + float f0 = relu2_f32(static_cast(input[inputOffset + i * 2])); + float f1 = relu2_f32(static_cast(input[inputOffset + i * 2 + 1])); + if constexpr (std::is_same_v) + { + packedVals[i] = __floats2half2_rn(f0, f1); + } + else + { + packedVals[i] = __floats2bfloat162_rn(f0, f1); + } + } + } + else + { +#pragma unroll + for (int i = 0; i < kPackedPerThread; i++) + { + if constexpr (std::is_same_v) + { + packedVals[i] = __float2half2_rn(0.0f); + } + else + { + packedVals[i] = __float2bfloat162_rn(0.0f); + } + } + } + + // Absmax in native precision, then reduce across the SF group (2 threads). + auto localMax = cuda_abs(packedVals[0]); +#pragma unroll + for (int i = 1; i < kPackedPerThread; i++) + { + localMax = cuda_max(localMax, cuda_abs(packedVals[i])); + } + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + float vecMax = float(cuda_max(localMax.x, localMax.y)); + + // Scale-factor computation (identical to cvt_warp_fp16_to_fp4). 
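+        // SFValue is the per-16-element block scale (stored as e4m3); outputScale then
+        // rescales the inputs so the packed values fall in the e2m1 representable range.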
+ float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + __nv_fp8_e4m3 fp8SF = __nv_fp8_e4m3(SFValue); + uint8_t fp8SFVal = fp8SF.__x; + SFValue = static_cast(fp8SF); + + float outputScale + = vecMax != 0.0f ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; + + if (colIdx % kNumThreadsPerSf == 0) + { + auto sfOutPtr = cvt_quant_get_sf_out_offset(std::nullopt, rowIdx, colIdx, + std::optional(m), numColVecs, outputSf, QuantizationSFLayout::SWIZZLED); + if (sfOutPtr != nullptr) + { + *sfOutPtr = fp8SFVal; + } + } + + if (isValidCol) + { + float2 fp2Vals[kPackedPerThread]; +#pragma unroll + for (int i = 0; i < kPackedPerThread; i++) + { + if constexpr (std::is_same_v) + { + fp2Vals[i] = __half22float2(packedVals[i]); + } + else + { + fp2Vals[i] = __bfloat1622float2(packedVals[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + outputFp4[rowIdx * numColThreads + colIdx] = fp32_vec_to_e2m1(fp2Vals); + } + } +#else + if (threadIdx.x == 0 && blockIdx.x == 0) + { + printf("FP4 quantization requires SM100 (Blackwell) or later!\n"); + } +#endif +} + +template +void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf, + int m, int n, int sfVecSize, cudaStream_t stream) +{ + constexpr int kSfVecSize = 16; + int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread; + int threadsPerBlock = min(512, numColThreadsPadded); + threadsPerBlock = max(32, ((threadsPerBlock + 31) / 32) * 32); + + fusedRelu2QuantizeKernel<<>>( + input, sfScale, reinterpret_cast(outputFp4), reinterpret_cast(outputSf), m, n); +} + +template void invokeFusedRelu2Quantize( + half const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t); + +#ifdef ENABLE_BF16 +template void invokeFusedRelu2Quantize<__nv_bfloat16>( + __nv_bfloat16 const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t); +#endif + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/fusedActivationQuant.h b/cpp/tensorrt_llm/kernels/fusedActivationQuant.h new file mode 100644 index 0000000000..6b9a20e913 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/fusedActivationQuant.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "tensorrt_llm/common/config.h" +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +template +void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf, + int m, int n, int sfVecSize, cudaStream_t stream); + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h index 0e05e0a835..ebf0e1bcd2 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h @@ -37,6 +37,7 @@ struct GeneralFP4AddBiasResidualPreLayerNormParam T const* bias = nullptr; T const* gamma = nullptr; T const* beta = nullptr; + T* high_precision_normed_output = nullptr; int m; int n; diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh index 6a925c5510..c618945ec5 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh @@ -276,7 +276,7 @@ struct LowLatencyLayerNorm } typename PackType::type normed_output; - typename PackType::type + typename PackType::type high_precision_normed_output; for (int j = 0; j < Traits::PACKED_ELEMS_PER_COMPUTE; j++) { @@ -300,7 +300,7 @@ struct LowLatencyLayerNorm } if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT) { - high_precision_normed_output.array[j] = normed_out; + high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out; } if constexpr (Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR) { diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh index 5359b9dc55..48906bb07a 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh @@ -690,7 +690,7 @@ struct WarpSpecializedLayerNorm typename PackType::type normed_output; typename PackType::type output; - typename PackType::type + typename PackType::type high_precision_normed_output; #pragma unroll Traits::PACKED_ELEMS_PER_COMPUTE @@ -719,6 +719,11 @@ struct WarpSpecializedLayerNorm normed_out += beta[j]; } + if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT) + { + high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out; + } + if constexpr (Traits::OUTPUT_SCALE != SCALE_TYPE::NONE) { static_assert(Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR); @@ -730,11 +735,6 @@ struct WarpSpecializedLayerNorm output.array[j] = (typename Traits::InputType) data[m_offset][i][j]; } - if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT) - { - high_precision_normed_output.array[j] = normed_out; - } - normed_output.array[j] = (typename Traits::OutputType) normed_out; } diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h index c7579251fb..c4b93024e9 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h @@ -44,7 +44,7 @@ enum class SCALE_TYPE }; template -void invokeWSLayerNorm(WarpSpecializedParam param, bool use_rms_norm, int ctas); +void invokeWSLayerNorm(WarpSpecializedParam param, bool use_rms_norm, int ctas, bool output_hp_norm = 
false); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu index 9103491cdd..3623d630d4 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu @@ -31,7 +31,8 @@ TRTLLM_NAMESPACE_BEGIN namespace kernels { template + int _M_BLOCK, int _N_BLOCK, int _STAGES = 3, bool _PERSISTENT_MODE = true, bool _LOW_LATENCY_MODE = false, + bool _HIGH_PRECISION_NORMED_OUTPUT = false> struct FP4AddBiasResidualPreLayerNormTraits { @@ -59,12 +60,12 @@ struct FP4AddBiasResidualPreLayerNormTraits static constexpr bool PERSISTENT_MODE = _PERSISTENT_MODE; static constexpr bool LOW_LATENCY_MODE = _LOW_LATENCY_MODE; static constexpr bool PREFETCH_TO_L2 = false; - static constexpr bool HIGH_PRECISION_NORMED_OUTPUT = false; + static constexpr bool HIGH_PRECISION_NORMED_OUTPUT = _HIGH_PRECISION_NORMED_OUTPUT; }; template -void invokeWSLayerNormImpl( - WarpSpecializedParam> param, bool use_rms_norm, int ctas) +void invokeWSLayerNormImpl(WarpSpecializedParam> param, bool use_rms_norm, + int ctas, bool output_hp_norm) { auto _invoke = [&](auto traits) @@ -80,10 +81,11 @@ void invokeWSLayerNormImpl( { int waves = ((param.m + Traits::M_BLOCK - 1) / Traits::M_BLOCK + ctas - 1) / ctas; TLLM_LOG_DEBUG( - "Selected TILE_M = %d, N = %d, STAGE = %d, PERSISTENT_MODE = %d, LOW_LATENCY_MODE = %d for param M = " + "Selected TILE_M = %d, N = %d, STAGE = %d, PERSISTENT_MODE = %d, LOW_LATENCY_MODE = %d, " + "HIGH_PRECISION_NORMED_OUTPUT = %d for param M = " "%d, N = %d, num_sms = %d. (waves = %d)\n", Traits::M_BLOCK, Traits::N_BLOCK, Traits::STAGES, Traits::PERSISTENT_MODE, Traits::LOW_LATENCY_MODE, - param.m, param.n, ctas, waves); + Traits::HIGH_PRECISION_NORMED_OUTPUT, param.m, param.n, ctas, waves); printed = true; } @@ -117,15 +119,32 @@ void invokeWSLayerNormImpl( constexpr auto PERSISTENT = decltype(persistent)::value; constexpr auto LOW_LATENCY_MODE = decltype(low_latency_mode)::value; + // Select kernel variant based on use_rms_norm and output_hp_norm if (use_rms_norm) { - _invoke(FP4AddBiasResidualPreLayerNormTraits, T, T, float, - true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE>{}); + if (output_hp_norm) + { + _invoke(FP4AddBiasResidualPreLayerNormTraits, T, T, float, + true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{}); + } + else + { + _invoke(FP4AddBiasResidualPreLayerNormTraits, T, T, float, + true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{}); + } } else { - _invoke(FP4AddBiasResidualPreLayerNormTraits, T, T, float, - false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE>{}); + if (output_hp_norm) + { + _invoke(FP4AddBiasResidualPreLayerNormTraits, T, T, float, + false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{}); + } + else + { + _invoke(FP4AddBiasResidualPreLayerNormTraits, T, T, float, + false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{}); + } } }; @@ -308,16 +327,18 @@ void invokeWSLayerNormImpl( template <> void invokeWSLayerNorm>( - WarpSpecializedParam> param, bool use_rms_norm, int ctas) + WarpSpecializedParam> param, bool use_rms_norm, int ctas, + bool output_hp_norm) { - invokeWSLayerNormImpl(param, use_rms_norm, ctas); + invokeWSLayerNormImpl(param, use_rms_norm, ctas, output_hp_norm); } template <> void invokeWSLayerNorm>( - WarpSpecializedParam> param, bool 
use_rms_norm, int ctas) + WarpSpecializedParam> param, bool use_rms_norm, int ctas, + bool output_hp_norm) { - invokeWSLayerNormImpl(param, use_rms_norm, ctas); + invokeWSLayerNormImpl(param, use_rms_norm, ctas, output_hp_norm); } } // namespace kernels diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index cd4974d917..08c7baf9c6 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -67,6 +67,7 @@ add_library( dsv3FusedAGemmOp.cpp fusedQKNormRopeOp.cpp fusedAddRMSNormQuant.cpp + fusedActivationQuant.cpp fusedTopkSoftmax.cpp gatherTreeOp.cpp groupRmsNormOp.cpp diff --git a/cpp/tensorrt_llm/thop/fusedActivationQuant.cpp b/cpp/tensorrt_llm/thop/fusedActivationQuant.cpp new file mode 100644 index 0000000000..ff25405fee --- /dev/null +++ b/cpp/tensorrt_llm/thop/fusedActivationQuant.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/fusedActivationQuant.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/thop/thUtils.h" + +#include +#include +#include + +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace torch_ext +{ + +std::tuple fused_relu2_quantize( + at::Tensor const& input, at::Tensor const& sf_scale, int64_t sf_vec_size) +{ + CHECK_TH_CUDA(input); + CHECK_CONTIGUOUS(input); + CHECK_INPUT(sf_scale, torch::kFloat32); + + auto const& inputShape = input.sizes(); + TORCH_CHECK(inputShape.size() == 2, "input should be 2D tensor [M, N]."); + + int64_t const m = inputShape[0]; + int64_t const n = inputShape[1]; + + TORCH_CHECK(sf_vec_size == 16, "sf_vec_size must be 16 for NVFP4."); + TORCH_CHECK(n % sf_vec_size == 0, "N must be divisible by sf_vec_size."); + + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + at::Tensor output_fp4 = at::detail::empty_cuda({m, n / 2}, torch::kUInt8, input.device(), std::nullopt); + int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sf_vec_size); + at::Tensor output_sf = at::detail::empty_cuda({sfSize}, SF_DTYPE, input.device(), std::nullopt); + + float const* sfScalePtr = sf_scale.data_ptr(); + + if (input.scalar_type() == at::ScalarType::Half) + { + kernels::invokeFusedRelu2Quantize(reinterpret_cast(input.data_ptr()), sfScalePtr, + output_fp4.data_ptr(), output_sf.data_ptr(), static_cast(m), static_cast(n), + static_cast(sf_vec_size), stream); + } + else if (input.scalar_type() == at::ScalarType::BFloat16) + { +#ifdef ENABLE_BF16 + kernels::invokeFusedRelu2Quantize<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()), + sfScalePtr, output_fp4.data_ptr(), output_sf.data_ptr(), static_cast(m), + static_cast(n), static_cast(sf_vec_size), stream); +#else + C10_THROW_ERROR(NotImplementedError, "BFloat16 not enabled."); +#endif + } + else + { + C10_THROW_ERROR(NotImplementedError, "fused_relu2_quantize only supports 
fp16/bf16."); + } + + return std::make_tuple(output_fp4, output_sf); +} + +} // namespace torch_ext + +TRTLLM_NAMESPACE_END + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def("fused_relu2_quantize(Tensor input, Tensor sf_scale, int sf_vec_size=16) -> (Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("fused_relu2_quantize", &tensorrt_llm::torch_ext::fused_relu2_quantize); +} diff --git a/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp b/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp index 0d76aff4f8..b764c18043 100644 --- a/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp +++ b/cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp @@ -43,16 +43,18 @@ namespace torch_ext // gamma: [N] - RMSNorm weight (fp16/bf16) // sf_scale: [1] - optional scale factor for FP4 quantization (float) // use_rms_norm: bool - if true use RMSNorm, else use LayerNorm +// output_hp_norm: bool - if true, also output high precision normalized values (same dtype as input) for MoE gate. // Returns: // normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed) // output: [M, N] - pre-norm output (input + residual), same dtype as input // sf_out: scale factors for FP4 (uint8_t), swizzled layout +// high_precision_normed_output: [M, N] - normalized output before quant (only if output_hp_norm=true, else empty) // // NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture. // NOTE: Hidden dimension N must be >= 2048 and <= 16384. -std::tuple fused_add_rms_norm_quant(at::Tensor const& input, - at::Tensor const& residual, at::Tensor const& gamma, std::optional const& sf_scale, bool use_rms_norm, - double eps) +std::tuple> fused_add_rms_norm_quant( + at::Tensor const& input, at::Tensor const& residual, at::Tensor const& gamma, + std::optional const& sf_scale, bool use_rms_norm, double eps, bool output_hp_norm) { CHECK_TH_CUDA(input); CHECK_CONTIGUOUS(input); @@ -116,6 +118,14 @@ std::tuple fused_add_rms_norm_quant(at::Tens int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize); at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt); at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize); + std::optional high_precision_normed_output = std::nullopt; + if (output_hp_norm) + { + at::Tensor hp_normed_output_padded + = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt); + high_precision_normed_output + = (m_padded == m) ? hp_normed_output_padded : hp_normed_output_padded.narrow(0, 0, m); + } // Get number of SMs for persistent kernel static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); @@ -152,12 +162,14 @@ std::tuple fused_add_rms_norm_quant(at::Tens param.bias = nullptr; \ param.gamma = reinterpret_cast(gamma.data_ptr()); \ param.beta = nullptr; \ + param.high_precision_normed_output \ + = output_hp_norm ? 
reinterpret_cast(high_precision_normed_output.value().data_ptr()) : nullptr; \ param.m = static_cast(m); \ param.n = static_cast(n); \ param.layernorm_eps = static_cast(eps); \ param.stream = stream; \ param.counters = counters; \ - tensorrt_llm::kernels::invokeWSLayerNorm(param, use_rms_norm, multiProcessorCount); \ + tensorrt_llm::kernels::invokeWSLayerNorm(param, use_rms_norm, multiProcessorCount, output_hp_norm); \ } while (0) if (input.scalar_type() == at::ScalarType::Half) @@ -180,7 +192,7 @@ std::tuple fused_add_rms_norm_quant(at::Tens #undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT // No explicit sync needed - kernel runs asynchronously on the stream - return std::make_tuple(normed_output, output, sf_out); + return std::make_tuple(normed_output, output, sf_out, high_precision_normed_output); } } // namespace torch_ext @@ -191,7 +203,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, " - "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)"); + "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6, bool output_hp_norm=False) -> (Tensor, Tensor, " + "Tensor, Tensor?)"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 3a3ee1238b..0a1a3b345f 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -1003,7 +1003,9 @@ def _register_fake(): sf_scale: Optional[torch.Tensor], use_rms_norm: bool = True, eps: float = 1e-5, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + output_hp_norm: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: m, n = input.shape # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32) normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32) @@ -1013,4 +1015,22 @@ def _register_fake(): sf_vec_size = 16 sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4 sf_out = input.new_empty((sf_size, ), dtype=torch.uint8) - return normed_output_fp4, output, sf_out + # high_precision_normed_output: [M, N] optional, only when output_hp_norm=True + hp_output = input.new_empty( + (m, n), dtype=input.dtype) if output_hp_norm else None + return normed_output_fp4, output, sf_out, hp_output + + @torch.library.register_fake("trtllm::fused_relu2_quantize") + def _( + input: torch.Tensor, + sf_scale: torch.Tensor, + sf_vec_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # input: 2D tensor [M, N] (bf16 or fp16) + # output_fp4: [M, N/2] (packed FP4 values, 2 values per byte) + # output_sf: swizzled scale factors + output_shape, scale_shape = fp4_utils.get_fp4_shape( + input.shape, sf_vec_size, is_swizzled_layout=True) + output_fp4 = input.new_empty(output_shape, dtype=torch.uint8) + output_sf = input.new_empty((scale_shape, ), dtype=torch.uint8) + return output_fp4, output_sf diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index e50c67991c..7871912075 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -14,7 +14,6 @@ # limitations under the License. 
import re -from typing import Dict, List, Optional import torch from torch import nn @@ -23,6 +22,7 @@ from transformers import AutoConfig, PretrainedConfig from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper from tensorrt_llm._torch.utils import ActivationType, relu2 +from tensorrt_llm.logger import logger from ..attention_backend import AttentionMetadata from ..distributed import AllReduce @@ -37,7 +37,7 @@ from ..modules.mlp import MLP from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from ..speculative import SpecMetadata -from ..utils import AuxStreamType, EventType +from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor from .modeling_deepseekv3 import DeepseekV3MTPHead from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import DecoderModel, register_auto_model @@ -121,7 +121,7 @@ class NemotronHMOE(nn.Module): self, model_config: ModelConfig[PretrainedConfig], layer_idx: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream], ): super().__init__() @@ -242,13 +242,20 @@ class NemotronHMOE(nn.Module): def forward( self, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor + | tuple[torch.Tensor | Fp4QuantizedTensor, torch.Tensor], attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: - assert hidden_states.shape[-1] == self.hidden_dim - orig_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_dim) + + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_hp = hidden_states + else: + hidden_states_hp = hidden_states + + assert hidden_states_hp.shape[-1] == self.hidden_dim + orig_shape = hidden_states_hp.shape + hidden_states_hp_2d = hidden_states_hp.view(-1, self.hidden_dim) all_rank_num_tokens = attn_metadata.all_rank_num_tokens def _compute_shared_output(): @@ -259,7 +266,8 @@ class NemotronHMOE(nn.Module): return shared_expert_output def _compute_routed_output(): - router_logits = self.gate(hidden_states) + # Gate uses high precision input for accurate routing decisions. + router_logits = self.gate(hidden_states_hp_2d) routed_hidden_states = hidden_states if self.use_latent_moe: @@ -301,7 +309,7 @@ class NemotronHLayer(DecoderLayer): # - -> MLPLayer # * -> TransformerLayer layer_type: str, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream], ): super().__init__() @@ -310,10 +318,31 @@ class NemotronHLayer(DecoderLayer): self.layer_idx = layer_idx self.layer_type = layer_type + self.is_nvfp4 = (model_config.quant_config is not None + and model_config.quant_config.quant_mode is not None + and model_config.quant_config.quant_mode.has_nvfp4()) + # The fused RMSNorm+NVFP4 CUDA kernel requires hidden_size to be + # a supported tile size. Non-power-of-2 hidden sizes within tile + # ranges may cause kernel hangs. Disable fused NVFP4 for such cases. + # Supported tile sizes: 2048, 4096, 8192, 16384 + _SUPPORTED_NVFP4_HIDDEN_SIZES = {2048, 4096, 8192, 16384} + if self.is_nvfp4 and config.hidden_size not in _SUPPORTED_NVFP4_HIDDEN_SIZES: + logger.warning_once( + f"Layer {layer_idx}: Disabling fused NVFP4 RMSNorm for hidden_size={config.hidden_size}. " + f"Supported sizes: {_SUPPORTED_NVFP4_HIDDEN_SIZES}. 
Using non-fused path.", + key=f"disable_nvfp4_rmsnorm_with_{config.hidden_size}") + self.is_nvfp4 = False + self.norm = RMSNorm( hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype, + # Enable fused NVFP4 quantization if possible. + # It might be overridden in `_try_attach_nvfp4_scale` function. + quantize_type="nvfp4" if self.is_nvfp4 else None, + # Enable high precision output for MoE layer (only with NVFP4). + # It might be overridden in `_try_attach_nvfp4_scale` function. + return_hp_output=layer_type == "E" and self.is_nvfp4, ) if layer_type == "M": @@ -343,29 +372,71 @@ class NemotronHLayer(DecoderLayer): else: raise ValueError(f"{layer_type} is not supported") + def post_load_weights(self): + """Post-process after loading weights.""" + if self.norm.is_nvfp4 and not hasattr(self.norm, 'nvfp4_scale'): + self._try_attach_nvfp4_scale() + + def _try_attach_nvfp4_scale(self): + """Attach input_scale from mixer's first linear to norm for fused RMSNorm+Quant.""" + # Normal handling for Mamba, MLP, and Attention layers. + first_linear_attr = { + 'M': 'in_proj', + '-': 'up_proj', + '*': 'qkv_proj' + }.get(self.layer_type) + if first_linear_attr: + first_linear = getattr(self.mixer, first_linear_attr, None) + if first_linear and hasattr(first_linear, 'input_scale'): + self.norm.nvfp4_scale = first_linear.input_scale + return + + # Special handling for MoE layer: fetch shared_expert.up_proj.input_scale + # as representation of the input scale. + if self.layer_type == 'E': + if (hasattr(self.mixer, 'shared_experts') + and self.mixer.shared_experts is not None + and hasattr(self.mixer.shared_experts, 'up_proj') + and hasattr(self.mixer.shared_experts.up_proj, + 'input_scale') and + self.mixer.shared_experts.up_proj.input_scale is not None): + self.norm.nvfp4_scale = self.mixer.shared_experts.up_proj.input_scale + # Enable high precision output for MoE layer. 
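+                # The router gate consumes this high-precision normed tensor, while the
+                # shared and routed experts consume the NVFP4-quantized output.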
+ self.norm.return_hp_output = True + return + + self.norm.is_nvfp4 = False + self.norm.return_hp_output = False + def forward( self, position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, - spec_metadata: Optional[SpecMetadata] = None, + residual: torch.Tensor | None = None, + spec_metadata: SpecMetadata | None = None, **kwargs, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = torch.zeros_like(hidden_states) - residual = hidden_states - - hidden_states = self.norm(hidden_states) + if self.norm.return_hp_output: + hidden_states, residual, high_precision_normed_output = self.norm( + hidden_states, residual) + hidden_states = (hidden_states, high_precision_normed_output) + else: + hidden_states, residual = self.norm(hidden_states, residual) hidden_states = self.mixer(hidden_states, attn_metadata, spec_metadata=spec_metadata, **kwargs) - hidden_states = torch.add(hidden_states, residual) + if spec_metadata is not None and spec_metadata.is_layer_capture( self.layer_idx): spec_metadata.maybe_capture_hidden_states(self.layer_idx, - hidden_states, None) + hidden_states, residual) - return hidden_states + return hidden_states, residual class NemotronHModel(DecoderModel): @@ -426,10 +497,10 @@ class NemotronHModel(DecoderModel): def forward( self, attn_metadata: AttentionMetadata, - input_ids: Optional[torch.IntTensor] = None, - position_ids: Optional[torch.IntTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - spec_metadata: Optional[SpecMetadata] = None, + input_ids: torch.IntTensor | None = None, + position_ids: torch.IntTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + spec_metadata: SpecMetadata | None = None, **kwargs, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): @@ -443,16 +514,15 @@ class NemotronHModel(DecoderModel): inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - + residual = torch.zeros_like(hidden_states) for layer in self.layers[:self.num_hidden_layers]: - hidden_states = layer(position_ids, - hidden_states, - attn_metadata, - spec_metadata=spec_metadata, - mamba_metadata=mamba_metadata) - - hidden_states = self.norm_f(hidden_states) - + hidden_states, residual = layer(position_ids, + hidden_states, + residual=residual, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + mamba_metadata=mamba_metadata) + hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states @@ -517,7 +587,7 @@ class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel, self.epilogue.extend(self.draft_model.mtp_layers) self.epilogue.append(self.spec_worker) - def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper): + def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): new_weights = weight_mapper.preprocess_weights(weights) super().load_weights(weights=new_weights, weight_mapper=weight_mapper) @@ -528,7 +598,7 @@ class NemotronHMTPDecoderLayer(NemotronHLayer): self, model_config: ModelConfig[NemotronHConfig], layer_idx: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream], has_start_projections: bool, has_end_norm: bool, layer_type: str, @@ -625,7 +695,7 @@ class NemotronHMTPDecoderLayer(NemotronHLayer): positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None = None, - attn_metadata: Optional[AttentionMetadata] = None, + attn_metadata: AttentionMetadata | None = 
None, ) -> tuple[torch.Tensor, torch.Tensor | None]: if self.has_start_projections: @@ -672,7 +742,7 @@ class NemotronHMTP(nn.Module): def __init__(self, model_config: ModelConfig[NemotronHConfig], layer_idx: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream], is_separate_draft_engine: bool = False, prefix: str = ""): super().__init__() @@ -744,8 +814,8 @@ class NemotronHMTP(nn.Module): hidden_states: torch.Tensor, embed_tokens: Embedding, attn_metadata: AttentionMetadata, - all_rank_num_tokens: Optional[List[int]] = None, - spec_metadata: Optional[SpecMetadata] = None, + all_rank_num_tokens: list[int] | None = None, + spec_metadata: SpecMetadata | None = None, **kwargs, ) -> torch.Tensor: inputs_embeds = embed_tokens(input_ids) diff --git a/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py b/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py new file mode 100644 index 0000000000..96b604dbfc --- /dev/null +++ b/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fused elementwise operations for Mamba2 prefill optimization.""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _extract_transpose_prefill_kernel( + src_ptr, + dst_ptr, + num_prefill_tokens, + d_in_proj, + d_inner, + conv_dim, + BLOCK_SEQ: tl.constexpr, + BLOCK_CONV: tl.constexpr, +): + """Extract src[0:num_prefill_tokens, d_inner:d_inner+conv_dim] and + transpose to dst[conv_dim, num_prefill_tokens].""" + pid_seq = tl.program_id(0) + pid_conv = tl.program_id(1) + + seq_offsets = pid_seq * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + conv_offsets = pid_conv * BLOCK_CONV + tl.arange(0, BLOCK_CONV) + + seq_mask = seq_offsets < num_prefill_tokens + conv_mask = conv_offsets < conv_dim + mask = seq_mask[:, None] & conv_mask[None, :] + + src_offsets = seq_offsets[:, None] * d_in_proj + (d_inner + conv_offsets[None, :]) + data = tl.load(src_ptr + src_offsets, mask=mask, other=0.0) + + dst_offsets = conv_offsets[:, None] * num_prefill_tokens + seq_offsets[None, :] + tl.store(dst_ptr + dst_offsets, tl.trans(data), mask=conv_mask[:, None] & seq_mask[None, :]) + + +def extract_transpose_xbc_prefill( + zxbcdt: torch.Tensor, + num_prefill_tokens: int, + d_inner: int, + conv_dim: int, +) -> torch.Tensor: + """ + Extract and transpose xbc slice from zxbcdt for causal_conv1d_fn. 
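+    Reading the xbc slice straight out of zxbcdt avoids the extra .contiguous()
+    copy that transposing a strided view would otherwise trigger inside
+    causal_conv1d_fn.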
+ + Input: zxbcdt[num_tokens, d_in_proj] + Output: [conv_dim, num_prefill_tokens] + """ + out = torch.empty(conv_dim, num_prefill_tokens, dtype=zxbcdt.dtype, device=zxbcdt.device) + + BLOCK_SEQ, BLOCK_CONV = 32, 128 + grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(conv_dim, BLOCK_CONV)) + + _extract_transpose_prefill_kernel[grid]( + zxbcdt, + out, + num_prefill_tokens, + zxbcdt.shape[1], + d_inner, + conv_dim, + BLOCK_SEQ, + BLOCK_CONV, + ) + return out + + +@triton.jit +def _fused_conv_output_transpose_kernel( + src_ptr, + out_x_ptr, + out_B_ptr, + out_C_ptr, + num_prefill_tokens, + d_inner, + bc_size, + x_tiles, + bc_tiles, + BLOCK_SEQ: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """ + Transpose and split conv1d output into x, B, C using linear grid mapping. + + Grid: tiles [0, x_tiles) -> x, [x_tiles, x_tiles+bc_tiles) -> B, rest -> C + """ + tile_id = tl.program_id(0) + + is_x = tile_id < x_tiles + is_B = (tile_id >= x_tiles) & (tile_id < x_tiles + bc_tiles) + + local_tile = tl.where( + is_x, tile_id, tl.where(is_B, tile_id - x_tiles, tile_id - x_tiles - bc_tiles) + ) + dim_size = tl.where(is_x, d_inner, bc_size) + num_dim_blocks = tl.cdiv(dim_size, BLOCK_DIM) + + pid_seq = local_tile // num_dim_blocks + pid_dim = local_tile % num_dim_blocks + + seq_offsets = pid_seq * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + dim_offsets = pid_dim * BLOCK_DIM + tl.arange(0, BLOCK_DIM) + + seq_mask = seq_offsets < num_prefill_tokens + dim_mask = dim_offsets < dim_size + + src_offset = tl.where(is_x, 0, tl.where(is_B, d_inner, d_inner + bc_size)) + src_indices = (src_offset + dim_offsets[:, None]) * num_prefill_tokens + seq_offsets[None, :] + data = tl.load(src_ptr + src_indices, mask=dim_mask[:, None] & seq_mask[None, :], other=0.0) + + out_ptr = tl.where(is_x, out_x_ptr, tl.where(is_B, out_B_ptr, out_C_ptr)) + dst_indices = seq_offsets[:, None] * dim_size + dim_offsets[None, :] + tl.store(out_ptr + dst_indices, tl.trans(data), mask=seq_mask[:, None] & dim_mask[None, :]) + + +def fused_split_rearrange_after_conv1d( + xbc: torch.Tensor, + d_inner: int, + n_groups: int, + d_state: int, + nheads: int, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split and rearrange causal_conv1d output into contiguous x, B, C tensors. 
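+    A single Triton launch replaces the separate split, rearrange, and per-tensor
+    .contiguous() calls performed on the unfused prefill path.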
+ + Input: xbc[conv_dim, num_prefill_tokens] + Output: x[1, num_prefill_tokens, nheads, head_dim], + B[1, num_prefill_tokens, n_groups, d_state], + C[1, num_prefill_tokens, n_groups, d_state] + """ + conv_dim, num_prefill_tokens = xbc.shape + bc_size = n_groups * d_state + + x_flat = torch.empty(num_prefill_tokens, d_inner, dtype=xbc.dtype, device=xbc.device) + B_flat = torch.empty(num_prefill_tokens, bc_size, dtype=xbc.dtype, device=xbc.device) + C_flat = torch.empty(num_prefill_tokens, bc_size, dtype=xbc.dtype, device=xbc.device) + + BLOCK_SEQ, BLOCK_DIM = 64, 64 + num_seq_blocks = triton.cdiv(num_prefill_tokens, BLOCK_SEQ) + x_tiles = num_seq_blocks * triton.cdiv(d_inner, BLOCK_DIM) + bc_tiles = num_seq_blocks * triton.cdiv(bc_size, BLOCK_DIM) + + _fused_conv_output_transpose_kernel[(x_tiles + 2 * bc_tiles,)]( + xbc, + x_flat, + B_flat, + C_flat, + num_prefill_tokens, + d_inner, + bc_size, + x_tiles, + bc_tiles, + BLOCK_SEQ, + BLOCK_DIM, + ) + + return ( + x_flat.view(1, num_prefill_tokens, nheads, head_dim), + B_flat.view(1, num_prefill_tokens, n_groups, d_state), + C_flat.view(1, num_prefill_tokens, n_groups, d_state), + ) diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py index 6888dbfaf4..748bd05158 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py @@ -17,6 +17,8 @@ import math from typing import Tuple import torch +import triton +import triton.language as tl from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ @@ -25,6 +27,86 @@ from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \ use_cpp_mamba_cache_manager +@triton.jit +def _cu_seqlens_triton_kernel( + cu_seqlens_ptr, # [num_seqs + 1] + chunk_indices_ptr, # [N] output + chunk_offsets_ptr, # [N] output + num_seqs: tl.constexpr, + chunk_size: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Computes chunk_indices and chunk_offsets in a single kernel launch.""" + pid = tl.program_id(0) + chunk_start = pid * BLOCK_SIZE + offsets = chunk_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + chunk_indices = offsets.to(tl.int64) + chunk_offsets = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + + p = 0 + for seq_idx in range(num_seqs - 1): + seq_start = tl.load(cu_seqlens_ptr + seq_idx + 1).to(tl.int64) + seq_end = tl.load(cu_seqlens_ptr + seq_idx + 2).to(tl.int64) + is_misaligned = (seq_start % chunk_size) > 0 + p = p + is_misaligned + s_chunk = seq_start // chunk_size + p + e_chunk = seq_end // chunk_size + p + ((seq_end % chunk_size) > 0) + in_range = (offsets >= s_chunk) & (offsets < e_chunk) + chunk_indices = tl.where(in_range & mask, chunk_indices - p, + chunk_indices) + is_start = (offsets == s_chunk) + chunk_offsets = tl.where(is_start & mask, seq_start % chunk_size, + chunk_offsets) + + tl.store(chunk_indices_ptr + offsets, chunk_indices.to(tl.int32), mask=mask) + tl.store(chunk_offsets_ptr + offsets, chunk_offsets.to(tl.int32), mask=mask) + + +def cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens: torch.Tensor, + chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Optimized version of cu_seqlens_to_chunk_indices_offsets.""" + device = cu_seqlens.device + num_seqs = cu_seqlens.numel() - 1 + + if num_seqs == 0: + return (torch.empty(0, dtype=torch.int, device=device), + torch.empty(0, dtype=torch.int, device=device)) + + cu = cu_seqlens.to(dtype=torch.int64) 
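+    # .item() forces a device-to-host sync; it is needed here to size the output
+    # buffers and the launch grid on the host.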
+ total_seqlens = cu[-1].item() + + if num_seqs == 1: + # Fast path for single sequence (no boundaries to process) + N = (total_seqlens + chunk_size - 1) // chunk_size + return (torch.arange(N, device=device, dtype=torch.int), + torch.zeros(N, device=device, dtype=torch.int)) + + seq_starts = cu[1:-1] + misaligned = ((seq_starts % chunk_size) > 0).to(torch.int64) + p = torch.cumsum(misaligned, dim=0) + extra_chunks = p[-1].item() if p.numel() > 0 else 0 + N = (total_seqlens + chunk_size - 1) // chunk_size + extra_chunks + chunk_indices = torch.empty(N, device=device, dtype=torch.int) + chunk_offsets = torch.empty(N, device=device, dtype=torch.int) + + BLOCK_SIZE = 256 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + _cu_seqlens_triton_kernel[grid]( + cu, + chunk_indices, + chunk_offsets, + num_seqs=num_seqs, + chunk_size=chunk_size, + N=N, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return chunk_indices, chunk_offsets + + def cu_seqlens_to_chunk_indices_offsets( cu_seqlens: torch.Tensor, chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]: @@ -117,15 +199,15 @@ class Mamba2Metadata: self.state_indices = torch.zeros(max_batch_size, dtype=torch.int32, device="cuda") - self._query_start_loc_long_buf = torch.arange(0, - max_batch_size + 1, - dtype=torch.long, - device="cuda") - self._query_start_loc_buf = torch.zeros(max_batch_size + 1, - dtype=torch.int, - device="cuda") - self.query_start_loc_long = self._query_start_loc_long_buf - self.query_start_loc = self._query_start_loc_buf + + # Pre-allocated buffers. + self._arange_buffer = torch.arange(max_batch_size + 1, + dtype=torch.int, + device="cuda") + self._arange_buffer_long = self._arange_buffer.to(torch.long) + self._cu_seqlens_long = torch.zeros(max_batch_size + 1, + dtype=torch.long, + device="cuda") def prepare(self, attn_metadata: AttentionMetadata): batch_size = attn_metadata.seq_lens.shape[0] @@ -158,47 +240,32 @@ class Mamba2Metadata: dtype=torch.int, out=self.cu_seqlens[1:num_contexts + 1]) torch.add(self.cu_seqlens[num_contexts], - torch.arange(1, - batch_size - num_contexts + 1, - dtype=self.cu_seqlens.dtype, - device=self.cu_seqlens.device), + self._arange_buffer[1:batch_size - num_contexts + 1], out=self.cu_seqlens[num_contexts + 1:batch_size + 1]) # Need both `query_start_loc` and `query_start_loc_long` because `causal_conv1d_fn` # accepts only `int32` while `chunk_gated_delta_rule` accepts only `long`. 
- self._query_start_loc_buf[:batch_size + - 1] = self.cu_seqlens[:batch_size + 1] - self.query_start_loc = self._query_start_loc_buf[:batch_size + 1] - self._query_start_loc_long_buf[:batch_size + 1].copy_( - self.query_start_loc.to(torch.long), non_blocking=True) - self.query_start_loc_long = self._query_start_loc_long_buf[: - batch_size - + 1] + self.query_start_loc = self.cu_seqlens[:batch_size + 1] + self._cu_seqlens_long[:batch_size + 1].copy_(self.query_start_loc) + self.query_start_loc_long = self._cu_seqlens_long[:batch_size + 1] self.seq_idx = torch.repeat_interleave( - torch.arange(num_contexts, - dtype=torch.int, - device=self.cu_seqlens.device), + self._arange_buffer[:num_contexts], repeats=context_lens, output_size=num_ctx_tokens).unsqueeze(0) num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq - self.has_initial_states[:num_contexts] = torch.tensor( - num_cached_tokens_per_seq[:num_contexts]) > 0 - # precomputed bool to avoid host<->device syncs during forward pass - self.use_initial_states = torch.any( - self.has_initial_states[:num_contexts]).item() + initial_states = [ + num_cached_tokens_per_seq[i] > 0 for i in range(num_contexts) + ] + self.use_initial_states = any(initial_states) if self.use_initial_states: - self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + self.has_initial_states[:num_contexts] = torch.tensor( + initial_states, dtype=torch.bool) + self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton( self.cu_seqlens[:num_contexts + 1], self.chunk_size) else: self.chunk_indices = None self.chunk_offsets = None else: self.query_start_loc = None - torch.arange(0, - batch_size + 1, - dtype=torch.long, - device=self.cu_seqlens.device, - out=self._query_start_loc_long_buf[:batch_size + 1]) - self.query_start_loc_long = self._query_start_loc_long_buf[: - batch_size - + 1] + self.query_start_loc_long = self._arange_buffer_long[:batch_size + + 1] diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 1874ea65d4..67c5bff5a4 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -33,6 +33,8 @@ from ..linear import Linear, TensorParallelMode from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .causal_conv1d_triton import \ causal_conv1d_update as causal_conv1d_update_triton +from .fuse_elementwise_ops import (extract_transpose_xbc_prefill, + fused_split_rearrange_after_conv1d) from .layernorm_gated import RMSNorm as RMSNormGated from .selective_state_update import \ selective_state_update as selective_state_update_native @@ -227,15 +229,17 @@ class Mamba2Mixer(nn.Module): # in_proj zxbcdt = self.in_proj(hidden_states) - z, xbc, dt = torch.split( - zxbcdt, - [self.tp_d_inner, self.tp_conv_dim, self.tp_nheads], - dim=-1, - ) + + # Split z and dt with views. + z = zxbcdt[:, :self.tp_d_inner] + dt = zxbcdt[:, self.tp_d_inner + self.tp_conv_dim:] z_p, z_d = torch.split(z, seqlen_split_size, dim=0) - xbc_p, xbc_d = torch.split(xbc, seqlen_split_size, dim=0) dt_p, dt_d = torch.split(dt, seqlen_split_size, dim=0) + # Decode path uses regular view since no transpose is needed. 
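+        # The prefill slice is extracted and transposed later by
+        # extract_transpose_xbc_prefill, right before causal_conv1d_fn.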
+ xbc_d = zxbcdt[num_prefill_tokens:num_actual_tokens, + self.tp_d_inner:self.tp_d_inner + self.tp_conv_dim] + # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( @@ -243,8 +247,8 @@ class Mamba2Mixer(nn.Module): zxbcdt.shape[0], (self.num_heads * self.head_dim) // self.tp_size, ], - dtype=hidden_states.dtype, - device=hidden_states.device, + dtype=zxbcdt.dtype, + device=zxbcdt.device, ) preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( preallocated_ssm_out, @@ -259,27 +263,29 @@ class Mamba2Mixer(nn.Module): has_initial_states = mamba_metadata.has_initial_states[: num_prefills] - xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1), + # Fused kernel to avoid expensive .contiguous() call in causal_conv1d_fn. + xbc_p_t = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens, + self.tp_d_inner, + self.tp_conv_dim) + xbc_p = causal_conv1d_fn(xbc_p_t, self.conv1d.weight, self.conv1d.bias, activation="silu", conv_states=conv_states, has_initial_state=has_initial_states, query_start_loc=cu_seqlens, - cache_indices=state_indices_p).transpose( - 0, 1) + cache_indices=state_indices_p) - x_p, B_p, C_p = torch.split(xbc_p.unsqueeze(0), [ + # Fused kernel to avoid expensive .contiguous() calls after split/rearrange. + x_p, B_p, C_p = fused_split_rearrange_after_conv1d( + xbc_p, self.tp_d_inner, - self.tp_ngroups * self.d_state, - self.tp_ngroups * self.d_state, - ], - dim=-1) - - x_p = rearrange(x_p, "b l (h p) -> b l h p", h=self.tp_nheads) + self.tp_ngroups, + self.d_state, + self.tp_nheads, + self.head_dim, + ) dt_p = dt_p.unsqueeze(0) - B_p = rearrange(B_p, "b l (g n) -> b l g n", g=self.tp_ngroups) - C_p = rearrange(C_p, "b l (g n) -> b l g n", g=self.tp_ngroups) z_p = rearrange(z_p.unsqueeze(0), "b l (h p) -> b l h p", h=self.tp_nheads) diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py index c41e6b47f0..7863148e37 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py @@ -26,6 +26,124 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") @triton.autotune( configs=[ + # ================================================================= + # Higher warp count configs for better latency hiding + # More warps = more instructions in flight = better memory latency hiding + # ================================================================= + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=2, + num_warps=8, # 8 warps = 256 threads per block + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=2, + num_warps=8, # 8 warps for better latency hiding + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=2, + num_warps=8, + ), + # Smaller tiles with more stages for software pipelining + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64 + }, + num_stages=2, + num_warps=4, + ), + # ================================================================= + # Low register pressure configs (num_stages=1) for large dstate + # ================================================================= + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 64 + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=1, + num_warps=4, + ), + # num_stages=2 configs - moderate register pressure + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64 + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=2, + num_warps=4, + ), + # Original configs for smaller dstate values triton.Config( { "BLOCK_SIZE_M": 128, @@ -355,14 +473,17 @@ def _chunk_scan_fwd_kernel( if not HAS_INITSTATES: # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + # Use exp2 for faster computation: exp(x) = exp2(x * log2(e)) + scale_m = tl.where(seq_idx_m == seq_idx_prev, + tl.math.exp2(dA_cs_m * 1.4426950408889634), 0.0) else: # - if there is initstates, we will rely on prev_states, no zeroing # required. - scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + scale_m = tl.math.exp2( + (dA_cs_m - dA_cs_m_boundary) * 1.4426950408889634) else: - scale_m = tl.exp(dA_cs_m) + scale_m = tl.math.exp2(dA_cs_m * 1.4426950408889634) if BLOCK_SIZE_DSTATE <= 128: C = tl.load( C_ptrs, @@ -421,7 +542,9 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. 
- cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + # Use exp2 for faster computation: exp(x) = exp2(x * log2(e)) + cb *= tl.math.exp2( + (dA_cs_m[:, None] - dA_cs_k[None, :]) * 1.4426950408889634) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py index b58f89eb77..bfd9c448e3 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py @@ -128,6 +128,54 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( configs=[ + # Small headdim/dstate configs (hdim<=64, dstate<=128) - increased parallelism + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4, + ), + # Low register pressure configs for large dstate (dstate=128) + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64 + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64 + }, + num_stages=2, + num_warps=4, + ), + # Original configs for larger headdim/dstate values triton.Config( { "BLOCK_SIZE_M": 128, @@ -175,40 +223,13 @@ def _chunk_cumsum_fwd_kernel( ), triton.Config( { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32 }, num_stages=4, num_warps=4, ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32 - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32 - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32 - }, - num_stages=4, - num_warps=2, - ), ], key=["hdim", "dstate", "chunk_size"], ) @@ -351,6 +372,41 @@ def _chunk_state_fwd_kernel( @triton.autotune( configs=[ + # Small headdim/dstate configs for better parallelism + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32 + }, + num_stages=3, + num_warps=4), + # Low register pressure configs + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64 + }, + num_stages=2, + num_warps=4), + # Original configs for larger dimensions triton.Config( { "BLOCK_SIZE_M": 128, @@ -393,36 +449,12 @@ def _chunk_state_fwd_kernel( num_warps=4), triton.Config( { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32 }, num_stages=4, num_warps=4), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32 - }, - num_stages=5, - num_warps=2), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32 - }, - num_stages=5, - num_warps=2), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32 - }, - num_stages=4, - num_warps=2), ], key=["hdim", "dstate", "chunk_size"], ) diff --git 
a/tensorrt_llm/_torch/modules/mlp.py b/tensorrt_llm/_torch/modules/mlp.py index d121457b48..cd3518cd7f 100644 --- a/tensorrt_llm/_torch/modules/mlp.py +++ b/tensorrt_llm/_torch/modules/mlp.py @@ -8,23 +8,25 @@ from tensorrt_llm.mapping import Mapping from ..model_config import ModelConfig from ..peft.lora.layer import LoraLayer, LoraModuleType +from ..utils import Fp4QuantizedTensor, relu2 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig class MLP(nn.Module): - def __init__(self, - *, - hidden_size: int, - intermediate_size: int, - bias: bool, - activation: Callable[[torch.Tensor], torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, - config: Optional[ModelConfig] = None, - layer_idx: Optional[int] = None, - reduce_output: bool = True, - overridden_tp_size: Optional[int] = None): - + def __init__( + self, + *, + hidden_size: int, + intermediate_size: int, + bias: bool, + activation: Callable[[torch.Tensor], torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + config: Optional[ModelConfig] = None, + layer_idx: Optional[int] = None, + reduce_output: bool = True, + overridden_tp_size: Optional[int] = None, + ): super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -81,7 +83,22 @@ class MLP(nn.Module): lora=self.down_lora, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, - reduce_output=reduce_output) + reduce_output=reduce_output, + ) + + self._use_fused_relu2_quant = False + + def create_weights(self): + self.up_proj.create_weights() + self.down_proj.create_weights() + + has_nvfp4 = hasattr(self.down_proj, + 'has_nvfp4') and self.down_proj.has_nvfp4 + has_kernel = hasattr(torch.ops.trtllm, 'fused_relu2_quantize') + has_scale = hasattr(self.down_proj, 'input_scale') + is_relu2 = self.activation is relu2 + + self._use_fused_relu2_quant = has_nvfp4 and has_kernel and has_scale and is_relu2 def forward( self, @@ -92,11 +109,34 @@ class MLP(nn.Module): return self.forward_lora(x, lora_params=lora_params) x_up = self.up_proj(x) - x_act = self.activation(x_up) + + if self._use_fused_relu2_quant: + x_act = self._fused_relu2_quant(x_up) + else: + x_act = self.activation(x_up) + x_down = self.down_proj(x_act) return x_down + def _fused_relu2_quant(self, x: torch.Tensor) -> Fp4QuantizedTensor: + x_flat = x.view(-1, x.shape[-1]) + + if not x_flat.is_contiguous(): + x_flat = x_flat.contiguous() + + if x_flat.dtype not in (torch.float16, torch.bfloat16): + x_flat = x_flat.to(torch.bfloat16) + + fp4_tensor, sf_tensor = torch.ops.trtllm.fused_relu2_quantize( + x_flat, self.down_proj.input_scale, 16) + + return Fp4QuantizedTensor( + fp4_tensor=fp4_tensor, + scaling_factor=sf_tensor, + is_sf_swizzled=True, + ) + def forward_lora( self, x: torch.Tensor, diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index d6e0a5994b..3bf1d8fb30 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -41,6 +41,7 @@ class RMSNorm(nn.Module): use_gemma: bool = False, quantize_type: Optional[str] = None, use_cuda_tile: bool = False, + return_hp_output: bool = False, ): super().__init__() @@ -72,6 +73,7 @@ class RMSNorm(nn.Module): self.variance_epsilon = eps self.use_gemma = use_gemma self.use_cuda_tile = use_cuda_tile + self.return_hp_output = return_hp_output def forward( self, @@ -80,7 +82,8 @@ class RMSNorm(nn.Module): Optional[torch.Tensor], _ArgumentNotSpecifiedSentinelType] = 
_ARGUMENT_NOT_SPECIFIED_SENTINEL, ) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[ - torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]: + torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]], Tuple[ + Fp4QuantizedTensor, torch.Tensor, torch.Tensor]]: has_residual = residual is not self._ARGUMENT_NOT_SPECIFIED_SENTINEL if not has_residual: residual = None @@ -116,14 +119,16 @@ class RMSNorm(nn.Module): sf_scale = nvfp4_scale.contiguous() - normed_fp4_i32, residual_out_2d, sf_fused = torch.ops.trtllm.fused_add_rms_norm_quant( + results = torch.ops.trtllm.fused_add_rms_norm_quant( hs_2d, res_2d, gamma, sf_scale, True, eps=self.variance_epsilon, + output_hp_norm=self.return_hp_output, ) + normed_fp4_i32, residual_out_2d, sf_fused = results[:3] normed_fp4_u8 = normed_fp4_i32.view(torch.uint8) if len(orig_shape) != 2: normed_fp4_u8 = normed_fp4_u8.reshape(*orig_shape[:-1], n // 2) @@ -132,9 +137,21 @@ class RMSNorm(nn.Module): residual_out = residual_out_2d hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused) - return (hidden_states_fused, - residual_out) if has_residual else hidden_states_fused - elif self.use_cuda_tile: + + outputs = [hidden_states_fused] + if has_residual: + outputs.append(residual_out) + if self.return_hp_output: + high_precision_normed_output = results[3].reshape(orig_shape) + outputs.append(high_precision_normed_output) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + if self.return_hp_output: + raise ValueError( + "Auxiliary high precision output is only supported for NVFP4 fused path" + ) + + if self.use_cuda_tile: if isinstance(residual, torch.Tensor): # Use fused residual kernel hidden_states = hidden_states.contiguous() diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index b8e4a04575..03ca777253 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -593,7 +593,9 @@ class MambaCacheManager(BaseResourceManager): return self._impl.get_intermediate_conv_states(layer_idx) def is_speculative(self) -> bool: - assert not self._use_cpp, "is_speculative is not supported in CppMambaCacheManager" + if self._use_cpp: + # CppMambaCacheManager does not support speculative decoding for now. 
+ return False return self._impl.is_speculative() def mamba_layer_cache( diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py index 5b5355a28d..72e4bd046e 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py @@ -1,5 +1,6 @@ import contextlib import functools +import inspect import itertools import os import unittest.mock @@ -443,9 +444,9 @@ class Runner: def forward(position_ids, hidden_states, attn_metadata, residual, **kwargs): # TODO: to be more general, we should call DecoderModel.forward - residual_fusion = hasattr(model.model.layers[layer_indices[0]], "next_layer_layernorm") for layer_idx in layer_indices: layer = model.model.layers[layer_idx] + residual_fusion = "residual" in inspect.signature(layer.forward).parameters if residual_fusion: hidden_states, residual = layer( position_ids, hidden_states, attn_metadata, residual, **kwargs diff --git a/tests/unittest/_torch/modules/mamba/test_causal_conv1d.py b/tests/unittest/_torch/modules/mamba/test_causal_conv1d.py new file mode 100644 index 0000000000..1c74bb6173 --- /dev/null +++ b/tests/unittest/_torch/modules/mamba/test_causal_conv1d.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import torch.nn.functional as F + +from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID + + +def mamba_conv1d_ref(x, past_conv_state, conv_weight, conv_bias, apply_silu): + """ + Reference implementation for causal conv1d. 
+ + Arguments: + x: [batch_size, dim, seq_len] + past_conv_state: [batch_size, dim, dconv-1] + conv_weight: [dim, 1, dconv] + conv_bias: [dim] + Output: + y: [batch_size, dim, seq_len] + present_conv_state: [batch_size, dim, dconv-1] + """ + assert x.dim() == 3 + assert past_conv_state.dim() == 3 + assert conv_weight.dim() == 3 + assert conv_bias.dim() == 1 + batch_size, dim, seq_len = x.shape + assert past_conv_state.shape[0] == batch_size + assert past_conv_state.shape[1] == dim + dconv = past_conv_state.shape[2] + 1 + assert conv_weight.shape[0] == dim + assert conv_weight.shape[1] == 1 + assert conv_weight.shape[2] == dconv + + padded_x = torch.cat([past_conv_state, x], dim=2) + present_conv_state = padded_x[:, :, -(dconv - 1) :] + x_conv = F.conv1d(padded_x, conv_weight, bias=conv_bias, groups=dim) + + y = F.silu(x_conv) if apply_silu else x_conv + return y, present_conv_state + + +def trtllm_causal_conv1d_available(): + """Check if trtllm.causal_conv1d_fwd is available.""" + return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "causal_conv1d_fwd") + + +skip_unsupported = pytest.mark.skipif( + not torch.cuda.is_available() or not trtllm_causal_conv1d_available(), + reason="Requires CUDA and trtllm.causal_conv1d_fwd op", +) + + +@skip_unsupported +class TestCausalConv1d: + """Tests for the causal_conv1d CUDA kernel.""" + + @pytest.mark.parametrize("dtype", ["float16", "bfloat16", "float32"]) + @pytest.mark.parametrize("apply_silu", [True, False]) + @pytest.mark.parametrize("dim", [256, 512, 1024, 2048]) + def test_basic_correctness(self, dtype, apply_silu, dim): + """Test basic correctness against reference implementation.""" + torch.manual_seed(42) + device = "cuda" + torch_dtype = getattr(torch, dtype) + + batch_size = 4 + seq_len = 32 + dconv = 4 + std_dev = 0.5 + x = torch.randn(batch_size, dim, seq_len, dtype=torch_dtype, device=device) + x = x * std_dev + conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=torch_dtype, device=device) + conv_weight = torch.randn(dim, 1, dconv, dtype=torch_dtype, device=device) + conv_bias = torch.randn(dim, dtype=torch_dtype, device=device) + x_kernel = x.clone() + conv_state_kernel = conv_state.clone() + + conv_weight_input = conv_weight.squeeze(1).contiguous() + torch.ops.trtllm.causal_conv1d_fwd( + x_kernel, + conv_weight_input, + conv_bias, + conv_state_kernel, + None, # query_start_loc + None, # cache_indices + None, # has_initial_state + apply_silu, + PAD_SLOT_ID, + ) + out_ref, conv_state_ref = mamba_conv1d_ref( + x, conv_state, conv_weight, conv_bias, apply_silu + ) + + torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) + def test_various_batch_sizes(self, batch_size): + """Test with various batch sizes.""" + torch.manual_seed(42) + device = "cuda" + dtype = torch.bfloat16 + dim = 1024 + seq_len = 64 + dconv = 4 + apply_silu = True + + x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5 + conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device) + conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device) + conv_bias = torch.randn(dim, dtype=dtype, device=device) + x_kernel = x.clone() + conv_state_kernel = conv_state.clone() + + conv_weight_input = conv_weight.squeeze(1).contiguous() + torch.ops.trtllm.causal_conv1d_fwd( + x_kernel, + conv_weight_input, + conv_bias, + conv_state_kernel, + None, + None, 
+ None, + apply_silu, + PAD_SLOT_ID, + ) + out_ref, conv_state_ref = mamba_conv1d_ref( + x, conv_state, conv_weight, conv_bias, apply_silu + ) + + torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1) + torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1) + + @pytest.mark.parametrize("dconv", [2, 3, 4]) + def test_various_kernel_widths(self, dconv): + """Test with different convolution kernel widths.""" + torch.manual_seed(42) + device = "cuda" + dtype = torch.bfloat16 + + batch_size = 4 + dim = 1024 + seq_len = 64 + apply_silu = True + x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5 + conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device) + conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device) + conv_bias = torch.randn(dim, dtype=dtype, device=device) + x_kernel = x.clone() + conv_state_kernel = conv_state.clone() + + conv_weight_input = conv_weight.squeeze(1).contiguous() + torch.ops.trtllm.causal_conv1d_fwd( + x_kernel, + conv_weight_input, + conv_bias, + conv_state_kernel, + None, + None, + None, + apply_silu, + PAD_SLOT_ID, + ) + out_ref, conv_state_ref = mamba_conv1d_ref( + x, conv_state, conv_weight, conv_bias, apply_silu + ) + + torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1) + torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1) + + def test_with_initial_state(self): + """Test with non-zero initial conv state.""" + torch.manual_seed(42) + device = "cuda" + dtype = torch.bfloat16 + + batch_size = 4 + dim = 1024 + seq_len = 32 + dconv = 4 + apply_silu = True + + x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5 + # Non-zero initial state + conv_state = torch.randn(batch_size, dim, dconv - 1, dtype=dtype, device=device) + conv_state = conv_state * 0.5 + conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device) + conv_bias = torch.randn(dim, dtype=dtype, device=device) + conv_state_kernel = conv_state.clone() + # Need to tell the kernel about initial state + has_initial_state = torch.ones(batch_size, dtype=torch.bool, device=device) + query_start_loc = torch.tensor( + [0] + [seq_len * (i + 1) for i in range(batch_size)], + dtype=torch.int32, + device=device, + ) + # Reshape for varlen format + x_varlen = x.transpose(1, 2).reshape(-1, dim).T.contiguous() + + conv_weight_input = conv_weight.squeeze(1).contiguous() + torch.ops.trtllm.causal_conv1d_fwd( + x_varlen, + conv_weight_input, + conv_bias, + conv_state_kernel, + query_start_loc, + None, # cache_indices + has_initial_state, + apply_silu, + PAD_SLOT_ID, + ) + + out_ref_list = [] + conv_state_ref_list = [] + for b in range(batch_size): + out_b, state_b = mamba_conv1d_ref( + x[b : b + 1], + conv_state[b : b + 1], + conv_weight, + conv_bias, + apply_silu, + ) + out_ref_list.append(out_b) + conv_state_ref_list.append(state_b) + out_ref = torch.cat(out_ref_list, dim=0) + conv_state_ref = torch.cat(conv_state_ref_list, dim=0) + x_kernel_reshaped = ( + x_varlen.T.reshape(batch_size, seq_len, dim).transpose(1, 2).contiguous() + ) + + torch.testing.assert_close(x_kernel_reshaped, out_ref, rtol=1e-2, atol=1e-1) + torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1) diff --git a/tests/unittest/_torch/modules/mamba/test_fuse_elementwise_ops.py b/tests/unittest/_torch/modules/mamba/test_fuse_elementwise_ops.py new file mode 100644 index 0000000000..0b07a309c8 --- /dev/null +++ 
b/tests/unittest/_torch/modules/mamba/test_fuse_elementwise_ops.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for fused elementwise operations in Mamba2 prefill.""" + +import pytest +import torch + +from tensorrt_llm._torch.modules.mamba.fuse_elementwise_ops import ( + extract_transpose_xbc_prefill, + fused_split_rearrange_after_conv1d, +) + +skip_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for triton kernels", +) + + +def extract_transpose_xbc_prefill_ref( + zxbcdt: torch.Tensor, + num_prefill_tokens: int, + d_inner: int, + conv_dim: int, +) -> torch.Tensor: + """Reference implementation for extract_transpose_xbc_prefill.""" + # Extract the xbc slice and transpose + xbc = zxbcdt[:num_prefill_tokens, d_inner : d_inner + conv_dim] + return xbc.transpose(0, 1).contiguous() + + +def fused_split_rearrange_after_conv1d_ref( + xbc: torch.Tensor, + d_inner: int, + n_groups: int, + d_state: int, + nheads: int, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Reference implementation for fused_split_rearrange_after_conv1d.""" + conv_dim, num_prefill_tokens = xbc.shape + bc_size = n_groups * d_state + + # Transpose and split + xbc_t = xbc.transpose(0, 1).contiguous() # [num_prefill_tokens, conv_dim] + x, B, C = torch.split(xbc_t, [d_inner, bc_size, bc_size], dim=-1) + x = x.contiguous().view(1, num_prefill_tokens, nheads, head_dim) + B = B.contiguous().view(1, num_prefill_tokens, n_groups, d_state) + C = C.contiguous().view(1, num_prefill_tokens, n_groups, d_state) + return x, B, C + + +@skip_no_cuda +@pytest.mark.parametrize("num_prefill_tokens", [1, 32, 128, 1024]) +@pytest.mark.parametrize( + "d_inner,conv_dim,d_in_proj", [(256, 512, 800), (512, 1024, 1600), (1024, 2048, 3200)] +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_extract_transpose_xbc_prefill(num_prefill_tokens, d_inner, conv_dim, d_in_proj, dtype): + """Test extract_transpose_xbc_prefill matches reference implementation.""" + torch.manual_seed(42) + device = torch.device("cuda") + + num_total_tokens = num_prefill_tokens + 16 + zxbcdt = torch.randn(num_total_tokens, d_in_proj, dtype=dtype, device=device) + out_ref = extract_transpose_xbc_prefill_ref(zxbcdt, num_prefill_tokens, d_inner, conv_dim) + out_fused = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens, d_inner, conv_dim) + + assert out_fused.shape == out_ref.shape, f"Shape mismatch: {out_fused.shape} vs {out_ref.shape}" + torch.testing.assert_close(out_fused, out_ref, rtol=1e-3, atol=1e-3) + + +@skip_no_cuda +@pytest.mark.parametrize("num_prefill_tokens", [1, 32, 128, 1024]) +@pytest.mark.parametrize( + "nheads,head_dim,n_groups,d_state", [(8, 64, 1, 128), (16, 64, 2, 64), (32, 64, 4, 64)] +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def 
test_fused_split_rearrange_after_conv1d( + num_prefill_tokens, nheads, head_dim, n_groups, d_state, dtype +): + """Test fused_split_rearrange_after_conv1d matches reference implementation.""" + torch.manual_seed(42) + device = torch.device("cuda") + + d_inner = nheads * head_dim + bc_size = n_groups * d_state + conv_dim = d_inner + 2 * bc_size + xbc = torch.randn(conv_dim, num_prefill_tokens, dtype=dtype, device=device) + x_ref, B_ref, C_ref = fused_split_rearrange_after_conv1d_ref( + xbc, d_inner, n_groups, d_state, nheads, head_dim + ) + x_fused, B_fused, C_fused = fused_split_rearrange_after_conv1d( + xbc, d_inner, n_groups, d_state, nheads, head_dim + ) + + assert x_fused.shape == x_ref.shape, f"x shape mismatch: {x_fused.shape} vs {x_ref.shape}" + assert B_fused.shape == B_ref.shape, f"B shape mismatch: {B_fused.shape} vs {B_ref.shape}" + assert C_fused.shape == C_ref.shape, f"C shape mismatch: {C_fused.shape} vs {C_ref.shape}" + torch.testing.assert_close(x_fused, x_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(B_fused, B_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(C_fused, C_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/unittest/_torch/modules/mamba/test_mamba2_metadata.py b/tests/unittest/_torch/modules/mamba/test_mamba2_metadata.py new file mode 100644 index 0000000000..ec5d72165c --- /dev/null +++ b/tests/unittest/_torch/modules/mamba/test_mamba2_metadata.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unit tests for Mamba2 metadata preparation optimizations.""" + +import pytest +import torch + +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import ( + cu_seqlens_to_chunk_indices_offsets, + cu_seqlens_to_chunk_indices_offsets_triton, +) + +skip_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for triton kernels", +) + + +@skip_no_cuda +class TestCuSeqlensToChunkIndicesOffsets: + """Tests for cu_seqlens_to_chunk_indices_offsets_triton function.""" + + def test_empty_sequence(self): + """Test with empty cu_seqlens (no sequences).""" + cu_seqlens = torch.tensor([0], dtype=torch.int, device="cuda") + chunk_size = 8 + + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + assert indices_triton.numel() == 0 + assert offsets_triton.numel() == 0 + + def test_single_sequence_aligned(self): + """Test with a single sequence that aligns with chunk size.""" + cu_seqlens = torch.tensor([0, 16], dtype=torch.int, device="cuda") + chunk_size = 8 + + indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + torch.testing.assert_close(indices_triton, indices_ref) + torch.testing.assert_close(offsets_triton, offsets_ref) + + def test_single_sequence_unaligned(self): + """Test with a single sequence that doesn't align with chunk size.""" + cu_seqlens = torch.tensor([0, 10], dtype=torch.int, device="cuda") + chunk_size = 8 + + indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + torch.testing.assert_close(indices_triton, indices_ref) + torch.testing.assert_close(offsets_triton, offsets_ref) + + def test_two_sequences_aligned(self): + """Test with two sequences, both aligned with chunk boundaries.""" + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int, device="cuda") + chunk_size = 8 + + indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + torch.testing.assert_close(indices_triton, indices_ref) + torch.testing.assert_close(offsets_triton, offsets_ref) + + def test_two_sequences_misaligned(self): + """Test with two sequences where second starts at misaligned position.""" + # Example from docstring: cu_seqlens = [0, 5, 10], chunk_size = 8 + # -> chunk_indices = [0, 0, 1], chunk_offsets = [0, 5, 0] + cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.int, device="cuda") + chunk_size = 8 + + indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + # Verify against expected values from docstring + expected_indices = torch.tensor([0, 0, 1], dtype=torch.int, device="cuda") + expected_offsets = torch.tensor([0, 5, 0], dtype=torch.int, device="cuda") + + torch.testing.assert_close(indices_ref, expected_indices) + torch.testing.assert_close(offsets_ref, expected_offsets) + + torch.testing.assert_close(indices_triton, indices_ref) + torch.testing.assert_close(offsets_triton, offsets_ref) + + @pytest.mark.parametrize("chunk_size", [8, 16, 32, 64, 128]) + def test_multiple_sequences_various_chunk_sizes(self, chunk_size): + """Test with multiple 
sequences and various chunk sizes.""" + # Create sequences with varying lengths + cu_seqlens = torch.tensor([0, 10, 25, 40, 60, 75], dtype=torch.int, device="cuda") + + indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + torch.testing.assert_close(indices_triton, indices_ref) + torch.testing.assert_close(offsets_triton, offsets_ref) + + def test_all_sequences_within_one_chunk(self): + """Test when all sequences fit within a single chunk.""" + cu_seqlens = torch.tensor([0, 2, 4, 6], dtype=torch.int, device="cuda") + chunk_size = 64 # Large chunk size + + indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton( + cu_seqlens, chunk_size + ) + + torch.testing.assert_close(indices_triton, indices_ref) + torch.testing.assert_close(offsets_triton, offsets_ref) diff --git a/tests/unittest/_torch/modules/test_fused_activation_quant.py b/tests/unittest/_torch/modules/test_fused_activation_quant.py new file mode 100644 index 0000000000..ba743ad64a --- /dev/null +++ b/tests/unittest/_torch/modules/test_fused_activation_quant.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unit tests for fused relu2 + NVFP4 quantization kernel.""" + +import pytest +import torch +import torch.nn.functional as F + +from tests.unittest.utils.util import getSMVersion + + +def fused_relu2_quantize_available(): + """Check if the fused_relu2_quantize op is available.""" + return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fused_relu2_quantize") + + +def fp4_quantize_available(): + """Check if the fp4_quantize op is available.""" + return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fp4_quantize") + + +skip_unless_fused_relu2_quantize = pytest.mark.skipif( + getSMVersion() < 100 or not fused_relu2_quantize_available(), + reason="Requires SM100+ (Blackwell) and trtllm.fused_relu2_quantize op", +) + +skip_unless_fused_relu2_and_fp4_quantize = pytest.mark.skipif( + getSMVersion() < 100 or not fused_relu2_quantize_available() or not fp4_quantize_available(), + reason="Requires SM100+ (Blackwell) and trtllm fused_relu2_quantize + fp4_quantize ops", +) + + +# FP4 E2M1 lookup table for reference implementation +E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) + + +def relu2(x: torch.Tensor) -> torch.Tensor: + """Reference relu2 activation: square(relu(x)).""" + return torch.square(F.relu(x)) + + +def cast_to_fp4(weight: torch.Tensor) -> torch.Tensor: + """Cast tensor values to FP4 E2M1 format (as uint8).""" + device = weight.device + + mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device) + mask_shape = list(weight.shape) + mask = mask.expand([*mask_shape, 7]) + + sign_bit = (weight < 0).to(torch.uint8) + weight_abs = weight.abs() + + ord_val = torch.searchsorted(E2M1_BOUNDS.to(device), weight_abs, out_int32=True).to(torch.uint8) + round_val = torch.any((weight_abs.unsqueeze(-1) == E2M1_BOUNDS.to(device)) * mask, dim=-1) + fp4_val = (sign_bit * 0b1000 + ord_val + round_val).to(torch.uint8) + return fp4_val + + +def quantize_nvfp4_ref( + input: torch.Tensor, sf_scale: torch.Tensor, sf_vec_size: int = 16 +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reference NVFP4 quantization implementation. 
+ + Args: + input: Input tensor [M, N], already activated (e.g., after relu2) + sf_scale: Per-tensor scaling factor (sf_scale = amax / (6 * 448)) + sf_vec_size: Block size for per-block scaling (default 16) + + Returns: + Tuple of (fp4_packed, scale_factors) + """ + m, n = input.shape + assert n % sf_vec_size == 0, f"N ({n}) must be divisible by sf_vec_size ({sf_vec_size})" + + # Reshape for block-wise quantization + input_blocked = input.view(m, n // sf_vec_size, sf_vec_size) + + # Compute per-block amax + per_block_amax = input_blocked.abs().amax(dim=-1).float() + + # Compute per-block scale: amax / 6.0 + per_block_scale = per_block_amax / 6.0 + + # Quantize per-block scale to FP8 + q_per_block_scale = per_block_scale / sf_scale + q_per_block_scale[per_block_scale == 0] = 1.0 + q_per_block_scale_fp8 = q_per_block_scale.to(torch.float8_e4m3fn) + + # Dequantize scale for actual quantization + scale_dequant = q_per_block_scale_fp8.float() * sf_scale + + # Scale the input + scale_expanded = scale_dequant.unsqueeze(-1).expand_as(input_blocked) + scaled_input = input_blocked / (scale_expanded + 1e-12) + scaled_input = scaled_input.view(m, n) + + # Cast to FP4 + fp4_vals = cast_to_fp4(scaled_input) + + # Pack two FP4 values into one uint8 + packed = (fp4_vals[..., 1::2] << 4) | fp4_vals[..., 0::2] + + return packed, q_per_block_scale_fp8 + + +def fused_relu2_quantize_ref( + input: torch.Tensor, sf_scale: torch.Tensor, sf_vec_size: int = 16 +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reference implementation for fused relu2 + NVFP4 quantization. + + Args: + input: Input tensor [M, N] + sf_scale: Per-tensor scaling factor + sf_vec_size: Block size for per-block scaling (default 16) + + Returns: + Tuple of (fp4_packed, scale_factors) + """ + # Apply relu2 activation + activated = relu2(input) + # Quantize to NVFP4 + return quantize_nvfp4_ref(activated, sf_scale, sf_vec_size) + + +@skip_unless_fused_relu2_quantize +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_relu2_quantize_zeros(dtype): + """Test fused_relu2_quantize with inputs that produce zeros after relu2.""" + device = torch.device("cuda") + + # All negative inputs -> relu2 produces all zeros + m, n = 32, 64 + input_tensor = -torch.abs(torch.randn(m, n, dtype=dtype, device=device)) + sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device) + fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(input_tensor, sf_scale, 16) + + assert fp4_fused.shape == (m, n // 2) + assert (fp4_fused == 0).all(), "All negative inputs should produce zero output" + + +@skip_unless_fused_relu2_and_fp4_quantize +@pytest.mark.parametrize("m", [1, 16, 64, 128]) +@pytest.mark.parametrize("n", [32, 64, 128, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_relu2_quantize_vs_separate_ops(m, n, dtype): + """ + Compare fused_relu2_quantize kernel output against separate relu2 + fp4_quantize. + + This test verifies that the fused CUDA kernel produces FP4 packed values that + closely match running relu2 activation followed by fp4_quantize separately. + + Note: Due to floating point precision differences in intermediate calculations + (e.g., FMA vs separate mul+add), a small percentage of values at quantization + boundaries may differ. We require >= 99% match rate. 
+ """ + torch.manual_seed(42) + device = torch.device("cuda") + + input_tensor = torch.randn(m, n, dtype=dtype, device=device) + activated = relu2(input_tensor) + sf_scale = (activated.abs().amax().float() / (6.0 * 448.0)).to(device) + sf_scale = sf_scale.view(1) + + fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize( + activated, + sf_scale, + 16, + False, + True, # use_ue8m0=False, is_sf_swizzled_layout=True + ) + fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize( + input_tensor.contiguous(), sf_scale, 16 + ) + + match_rate = (fp4_fused == fp4_separate).float().mean().item() + assert match_rate >= 0.99, ( + f"FP4 values match rate {match_rate:.4f} < 0.99 for shape ({m}, {n}), dtype {dtype}" + ) + + +@skip_unless_fused_relu2_and_fp4_quantize +def test_fused_relu2_quantize_vs_separate_ops_various_sf_scales(): + """ + Test with various sf_scale values to ensure consistent behavior. + """ + device = torch.device("cuda") + m, n = 64, 128 + dtype = torch.bfloat16 + + torch.manual_seed(123) + input_tensor = torch.randn(m, n, dtype=dtype, device=device) + activated = relu2(input_tensor) + + # Test with different sf_scale values + for scale_multiplier in [0.1, 1.0, 10.0]: + sf_scale = ( + (activated.abs().amax().float() / (6.0 * 448.0) * scale_multiplier).to(device).view(1) + ) + fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize( + activated, sf_scale, 16, False, True + ) + fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize( + input_tensor.contiguous(), sf_scale, 16 + ) + + match_rate = (fp4_fused == fp4_separate).float().mean().item() + assert match_rate >= 0.99, ( + f"FP4 values match rate {match_rate:.4f} < 0.99 with scale_multiplier={scale_multiplier}" + ) diff --git a/tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py b/tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py new file mode 100644 index 0000000000..f914b028e3 --- /dev/null +++ b/tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for fused_add_rms_norm_quant with/without output_hp_norm support.""" + +import pytest +import torch + +from tests.unittest.utils.util import getSMVersion + + +def fused_add_rms_norm_quant_available(): + """Check if the fused_add_rms_norm_quant op is available.""" + return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fused_add_rms_norm_quant") + + +skip_unsupported = pytest.mark.skipif( + getSMVersion() < 100 or not fused_add_rms_norm_quant_available(), + reason="Requires Blackwell+ (SM100+) and trtllm.fused_add_rms_norm_quant op", +) + + +def rms_norm_ref( + hidden_states: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + eps: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reference RMSNorm implementation with residual addition. 
+ + Args: + hidden_states: Input tensor [M, N] + residual: Residual tensor [M, N] + gamma: Weight tensor [N] + eps: Epsilon for numerical stability + + Returns: + Tuple of (normalized_output, residual_output) + """ + input_dtype = hidden_states.dtype + + # Add residual + hidden_states_fp32 = hidden_states.float() + residual.float() + residual_out = hidden_states_fp32.to(input_dtype) + + # RMSNorm + variance = hidden_states_fp32.pow(2).mean(-1, keepdim=True) + normed = hidden_states_fp32 * torch.rsqrt(variance + eps) + normed_output = (gamma.float() * normed).to(input_dtype) + + return normed_output, residual_out + + +def get_swizzled_sf_indices(m: int, n: int, sf_vec_size: int = 16) -> list[int]: + """ + Compute the valid indices in swizzled SF layout for given m and n. + + The swizzled layout uses 128x4 tiles: + - SF layout: [numMTiles, numKTiles, 32 (outerM), 4 (innerM), 4 (innerK)] + - Each SF block has 128 rows, padded to multiple of 128 + - Each SF block has columns padded to multiple of 4 + + Args: + m: Number of rows + n: Hidden dimension + sf_vec_size: Number of elements sharing one scale factor (default 16) + + Returns: + List of valid indices in the swizzled buffer + """ + num_col_vecs = n // sf_vec_size # Number of SF columns + indices = [] + + for m_idx in range(m): + for k_idx in range(num_col_vecs): + # Compute swizzled offset using 128x4 tile layout + inner_k_idx = k_idx % 4 + inner_k_stride = 1 + + inner_m_idx = (m_idx % 128) // 32 + inner_m_stride = 4 * inner_k_stride # 4 + + outer_m_idx = m_idx % 32 + outer_m_stride = 4 * inner_m_stride # 16 + + k_tile_idx = k_idx // 4 + k_tile_stride = 32 * outer_m_stride # 512 + + num_k_tiles = (num_col_vecs + 3) // 4 + m_tile_idx = m_idx // 128 + m_tile_stride = num_k_tiles * k_tile_stride + + offset = ( + m_tile_idx * m_tile_stride + + k_tile_idx * k_tile_stride + + outer_m_idx * outer_m_stride + + inner_m_idx * inner_m_stride + + inner_k_idx * inner_k_stride + ) + indices.append(offset) + + return indices + + +@skip_unsupported +@pytest.mark.parametrize("m", [1, 16, 64, 128]) +@pytest.mark.parametrize("n", [2048, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_add_rms_norm_quant_basic(m, n, dtype): + """Test basic functionality of fused_add_rms_norm_quant without hp_output.""" + torch.manual_seed(42) + device = torch.device("cuda") + + # Create input tensors + hidden_states = torch.randn(m, n, dtype=dtype, device=device) + residual = torch.randn(m, n, dtype=dtype, device=device) + gamma = torch.ones(n, dtype=dtype, device=device) + + # Compute sf_scale (per-tensor scale) + eps = 1e-6 + normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps) + sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1) + + # Run fused kernel without hp_output + normed_fp4, residual_out, sf_out, dummy_output = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps + ) + assert dummy_output is None + + # Verify output shapes + assert normed_fp4.shape[0] == m + assert residual_out.shape == (m, n) + assert residual_out.dtype == dtype + + +@skip_unsupported +@pytest.mark.parametrize("m", [1, 16, 64, 128]) +@pytest.mark.parametrize("n", [2048, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_add_rms_norm_quant_with_hp_output(m, n, dtype): + """Test fused_add_rms_norm_quant with output_hp_norm=True.""" + torch.manual_seed(42) + device = torch.device("cuda") + + # Create input tensors + hidden_states = 
torch.randn(m, n, dtype=dtype, device=device) + residual = torch.randn(m, n, dtype=dtype, device=device) + gamma = torch.ones(n, dtype=dtype, device=device) + + # Compute sf_scale + eps = 1e-6 + normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps) + sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1) + + # Run fused kernel with hp_output + results = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True + ) + + # Should return 4 tensors when output_hp_norm=True + assert len(results) == 4, f"Expected 4 outputs, got {len(results)}" + + normed_fp4, residual_out, sf_out, hp_normed_output = results + + # Verify output shapes + assert normed_fp4.shape[0] == m + assert residual_out.shape == (m, n) + assert hp_normed_output.shape == (m, n) + + # Verify dtypes + assert residual_out.dtype == dtype + assert hp_normed_output.dtype == dtype + + # Verify high precision output matches reference + torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2) + + # Verify residual output matches reference + torch.testing.assert_close(residual_out, residual_ref, rtol=1e-3, atol=1e-3) + + +@skip_unsupported +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_add_rms_norm_quant_hp_output_consistency(dtype): + """Test that hp_output is consistent with the quantized output.""" + torch.manual_seed(42) + device = torch.device("cuda") + + m, n = 64, 4096 + + # Create input tensors + hidden_states = torch.randn(m, n, dtype=dtype, device=device) + residual = torch.randn(m, n, dtype=dtype, device=device) + gamma = torch.ones(n, dtype=dtype, device=device) + + eps = 1e-6 + normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps) + sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1) + + # Run without hp_output + results_no_hp = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=False + ) + assert results_no_hp[3] is None + normed_fp4_no_hp, residual_out_no_hp, sf_out_no_hp = results_no_hp[:3] + + # Run with hp_output + results_hp = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True + ) + normed_fp4_hp, residual_out_hp, sf_out_hp, hp_normed_output = results_hp + + # The quantized outputs should be identical regardless of hp_output flag + torch.testing.assert_close(normed_fp4_hp, normed_fp4_no_hp, rtol=0, atol=0) + torch.testing.assert_close(residual_out_hp, residual_out_no_hp, rtol=0, atol=0) + # Compare only valid SF indices (swizzled layout pads rows to 128) + valid_sf_indices = get_swizzled_sf_indices(m, n) + sf_out_hp_valid = sf_out_hp[valid_sf_indices] + sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices] + torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0) + + +@skip_unsupported +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_add_rms_norm_quant_gamma_weight(dtype): + """Test fused_add_rms_norm_quant with non-trivial gamma weights.""" + torch.manual_seed(42) + device = torch.device("cuda") + + m, n = 32, 2048 + + # Create input tensors + hidden_states = torch.randn(m, n, dtype=dtype, device=device) + residual = torch.randn(m, n, dtype=dtype, device=device) + # Non-trivial gamma weights + gamma = torch.randn(n, dtype=dtype, device=device) * 0.5 + 1.0 + + eps = 1e-6 + normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, 
eps) + sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1) + + # Run with hp_output + results = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True + ) + normed_fp4, residual_out, sf_out, hp_normed_output = results + + # Verify high precision output matches reference + torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2) + + # Verify residual output matches reference + torch.testing.assert_close(residual_out, residual_ref, rtol=1e-3, atol=1e-3) + + +@skip_unsupported +def test_fused_add_rms_norm_quant_large_batch(): + """Test fused_add_rms_norm_quant with larger batch size.""" + torch.manual_seed(42) + device = torch.device("cuda") + + m, n = 512, 4096 + dtype = torch.bfloat16 + + hidden_states = torch.randn(m, n, dtype=dtype, device=device) + residual = torch.randn(m, n, dtype=dtype, device=device) + gamma = torch.ones(n, dtype=dtype, device=device) + + eps = 1e-6 + normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps) + sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1) + + results = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True + ) + normed_fp4, residual_out, sf_out, hp_normed_output = results + + assert hp_normed_output.shape == (m, n) + torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2) + + +@skip_unsupported +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_low_latency_layernorm_hp_output_consistency(dtype): + """ + Test that low_latency_layernorm hp_output is consistent with/without the flag. + + The quantized outputs should be identical regardless of output_hp_norm flag. + Uses m=1 to trigger the low_latency_layernorm path. + """ + torch.manual_seed(42) + device = torch.device("cuda") + + m, n = 1, 4096 # m=1 triggers low_latency_layernorm path + + hidden_states = torch.randn(m, n, dtype=dtype, device=device) + residual = torch.randn(m, n, dtype=dtype, device=device) + gamma = torch.ones(n, dtype=dtype, device=device) + + eps = 1e-6 + normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps) + sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1) + + # Run without hp_output + results_no_hp = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps + ) + assert len(results_no_hp) == 4, f"Expected 4 outputs, got {len(results_no_hp)}" + assert results_no_hp[3] is None, "Expected 4th output to be None when output_hp_norm=False" + normed_fp4_no_hp, residual_out_no_hp, sf_out_no_hp = results_no_hp[:3] + + # Run with hp_output + results_hp = torch.ops.trtllm.fused_add_rms_norm_quant( + hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True + ) + assert len(results_hp) == 4, f"Expected 4 outputs, got {len(results_hp)}" + normed_fp4_hp, residual_out_hp, sf_out_hp, hp_normed_output = results_hp + + # The quantized outputs should be identical regardless of hp_output flag + torch.testing.assert_close(normed_fp4_hp, normed_fp4_no_hp, rtol=0, atol=0) + torch.testing.assert_close(residual_out_hp, residual_out_no_hp, rtol=0, atol=0) + # Compare only valid SF indices (swizzled layout pads rows to 128) + valid_sf_indices = get_swizzled_sf_indices(m, n) + sf_out_hp_valid = sf_out_hp[valid_sf_indices] + sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices] + torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0)