Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)

Commit 421eb9e39c (parent ef7830d137)
[None][feat] Optimize NemotronH model with elementwise and nvfp4 fusion (#11273)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
@ -43,6 +43,8 @@ struct Causal_conv1d_fwd_kernel_traits
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
static_assert(kNThreads_ % 32 == 0, "kNThreads must be a multiple of 32 for warp shuffle");
|
||||
static_assert(sizeof(vec_t) == 16, "vec_t must be 16 bytes for warp shuffle optimization");
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
@ -123,7 +125,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i)
|
||||
{
|
||||
weight_vals[i] = float(weight[i * params.weight_width_stride]);
|
||||
weight_vals[i] = float(__ldg(&weight[i * params.weight_width_stride]));
|
||||
}
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
@ -144,20 +146,41 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
|
||||
x, *reinterpret_cast<input_t(*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
|
||||
int const lane_id = tidx & 31;
|
||||
vec_t high_val = reinterpret_cast<vec_t*>(x_vals_load)[1];
|
||||
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1)
|
||||
{
|
||||
smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
|
||||
smem_exchange[tidx] = high_val;
|
||||
}
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t*>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
|
||||
// Get neighbor data: use warp shuffle for most threads, shared memory for warp boundaries
|
||||
vec_t neighbor;
|
||||
uint32_t* high_val_p = reinterpret_cast<uint32_t*>(&high_val);
|
||||
uint32_t* nbr_p = reinterpret_cast<uint32_t*>(&neighbor);
|
||||
nbr_p[0] = __shfl_up_sync(0xFFFFFFFF, high_val_p[0], 1);
|
||||
nbr_p[1] = __shfl_up_sync(0xFFFFFFFF, high_val_p[1], 1);
|
||||
nbr_p[2] = __shfl_up_sync(0xFFFFFFFF, high_val_p[2], 1);
|
||||
nbr_p[3] = __shfl_up_sync(0xFFFFFFFF, high_val_p[3], 1);
|
||||
|
||||
// Lane 0 must use shared memory to handle the cross-warp boundary.
|
||||
// thread 0 uses the last element of the previous chunk.
|
||||
if (lane_id == 0)
|
||||
{
|
||||
neighbor = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
}
|
||||
reinterpret_cast<vec_t*>(x_vals_load)[0] = neighbor;
|
||||
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1)
|
||||
{
|
||||
smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
|
||||
smem_exchange[tidx] = high_val;
|
||||
}
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
@ -169,22 +192,33 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i)
|
||||
// Process 2 outputs at a time for better ILP (instruction level parallelism).
|
||||
for (int i = 0; i < kNElts; i += 2)
|
||||
{
|
||||
out_vals[i] = bias_val;
|
||||
float acc0 = bias_val;
|
||||
float acc1 = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w)
|
||||
{
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
float wt = weight_vals[w];
|
||||
acc0 = __fmaf_rn(wt, x_vals[kNElts + i - (kWidth - w - 1)], acc0);
|
||||
acc1 = __fmaf_rn(wt, x_vals[kNElts + i + 1 - (kWidth - w - 1)], acc1);
|
||||
}
|
||||
out_vals[i] = acc0;
|
||||
out_vals[i + 1] = acc1;
|
||||
}
|
||||
|
||||
if (params.silu_activation)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i)
|
||||
for (int i = 0; i < kNElts; i += 2)
|
||||
{
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
// SiLU: x * sigmoid(x) = x / (1 + exp(-x))
|
||||
// Using fast math: __expf and __frcp_rn
|
||||
float v0 = out_vals[i];
|
||||
float v1 = out_vals[i + 1];
|
||||
out_vals[i] = v0 * __frcp_rn(1.0f + __expf(-v0));
|
||||
out_vals[i + 1] = v1 * __frcp_rn(1.0f + __expf(-v1));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
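As a plain-Python reference for the rewritten inner loop above (a per-channel depthwise causal convolution followed by optional SiLU), here is a minimal NumPy sketch; the function name and array layout are illustrative only and not part of the kernel's interface.

import numpy as np

def causal_conv1d_ref(x, weight, bias, silu=True):
    # x: [seqlen] for one channel, weight: [kWidth], bias: scalar
    k_width = weight.shape[0]
    x_pad = np.concatenate([np.zeros(k_width - 1, dtype=np.float32), x.astype(np.float32)])
    out = np.empty(x.shape[0], dtype=np.float32)
    for i in range(x.shape[0]):
        # out[i] = bias + sum_w weight[w] * x[i - (kWidth - 1 - w)], as in the FMA loop above
        out[i] = bias + np.dot(weight.astype(np.float32), x_pad[i:i + k_width])
    if silu:
        out = out / (1.0 + np.exp(-out))  # SiLU: x * sigmoid(x)
    return out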
cpp/tensorrt_llm/kernels/fusedActivationQuant.cu (new file, 187 lines)
@ -0,0 +1,187 @@
|
||||
/*
|
||||
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/fusedActivationQuant.h"
|
||||
#include "tensorrt_llm/kernels/quantization.cuh"
|
||||
#include "tensorrt_llm/kernels/quantization.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
constexpr int kEltsPerThread = 8;
|
||||
|
||||
__device__ __forceinline__ float relu2_f32(float x)
|
||||
{
|
||||
float r = fmaxf(0.0f, x);
|
||||
return r * r;
|
||||
}
|
||||
|
||||
// Fused relu2 + NVFP4 quantization kernel.
//
// To match the unfused path (PyTorch relu2 -> cvt_warp_fp16_to_fp4), relu2 is
// computed in f32 then rounded back to native precision (bf16/fp16) before
// quantization. Absmax and scale-factor math follow cvt_warp_fp16_to_fp4 exactly.
// Column padding to a multiple of (4 * kSfVecSize) matches quantize_with_block_size
// for the swizzled SF layout.
template <typename T>
|
||||
__global__ void fusedRelu2QuantizeKernel(T const* __restrict__ input, float const* __restrict__ sfScale,
|
||||
uint32_t* __restrict__ outputFp4, uint32_t* __restrict__ outputSf, int m, int n)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
constexpr int kSfVecSize = 16;
|
||||
constexpr int kNumThreadsPerSf = kSfVecSize / kEltsPerThread;
|
||||
constexpr int kPackedPerThread = kEltsPerThread / 2;
|
||||
|
||||
using PackedType = std::conditional_t<std::is_same_v<T, half>, __half2, __nv_bfloat162>;
|
||||
|
||||
float const SFScaleVal = sfScale[0];
|
||||
int const numColThreads = n / kEltsPerThread;
|
||||
int const numColVecs = n / kSfVecSize;
|
||||
int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
|
||||
int const rowIdx = blockIdx.x;
|
||||
|
||||
if (rowIdx >= m)
|
||||
return;
|
||||
|
||||
for (int colIdx = threadIdx.x; colIdx < numColThreadsPadded; colIdx += blockDim.x)
|
||||
{
|
||||
bool const isValidCol = colIdx < numColThreads;
|
||||
PackedType packedVals[kPackedPerThread];
|
||||
|
||||
if (isValidCol)
|
||||
{
|
||||
int const inputOffset = rowIdx * n + colIdx * kEltsPerThread;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackedPerThread; i++)
|
||||
{
|
||||
float f0 = relu2_f32(static_cast<float>(input[inputOffset + i * 2]));
|
||||
float f1 = relu2_f32(static_cast<float>(input[inputOffset + i * 2 + 1]));
|
||||
if constexpr (std::is_same_v<T, half>)
|
||||
{
|
||||
packedVals[i] = __floats2half2_rn(f0, f1);
|
||||
}
|
||||
else
|
||||
{
|
||||
packedVals[i] = __floats2bfloat162_rn(f0, f1);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackedPerThread; i++)
|
||||
{
|
||||
if constexpr (std::is_same_v<T, half>)
|
||||
{
|
||||
packedVals[i] = __float2half2_rn(0.0f);
|
||||
}
|
||||
else
|
||||
{
|
||||
packedVals[i] = __float2bfloat162_rn(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Absmax in native precision, then reduce across the SF group (2 threads).
|
||||
auto localMax = cuda_abs(packedVals[0]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < kPackedPerThread; i++)
|
||||
{
|
||||
localMax = cuda_max(localMax, cuda_abs(packedVals[i]));
|
||||
}
|
||||
localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
float vecMax = float(cuda_max(localMax.x, localMax.y));
|
||||
|
||||
// Scale-factor computation (identical to cvt_warp_fp16_to_fp4).
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
__nv_fp8_e4m3 fp8SF = __nv_fp8_e4m3(SFValue);
|
||||
uint8_t fp8SFVal = fp8SF.__x;
|
||||
SFValue = static_cast<float>(fp8SF);
|
||||
|
||||
float outputScale
|
||||
= vecMax != 0.0f ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
|
||||
|
||||
if (colIdx % kNumThreadsPerSf == 0)
|
||||
{
|
||||
auto sfOutPtr = cvt_quant_get_sf_out_offset<uint32_t, kNumThreadsPerSf>(std::nullopt, rowIdx, colIdx,
|
||||
std::optional<int>(m), numColVecs, outputSf, QuantizationSFLayout::SWIZZLED);
|
||||
if (sfOutPtr != nullptr)
|
||||
{
|
||||
*sfOutPtr = fp8SFVal;
|
||||
}
|
||||
}
|
||||
|
||||
if (isValidCol)
|
||||
{
|
||||
float2 fp2Vals[kPackedPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackedPerThread; i++)
|
||||
{
|
||||
if constexpr (std::is_same_v<T, half>)
|
||||
{
|
||||
fp2Vals[i] = __half22float2(packedVals[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
fp2Vals[i] = __bfloat1622float2(packedVals[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
outputFp4[rowIdx * numColThreads + colIdx] = fp32_vec_to_e2m1(fp2Vals);
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
printf("FP4 quantization requires SM100 (Blackwell) or later!\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf,
|
||||
int m, int n, int sfVecSize, cudaStream_t stream)
|
||||
{
|
||||
constexpr int kSfVecSize = 16;
|
||||
int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
|
||||
int threadsPerBlock = min(512, numColThreadsPadded);
|
||||
threadsPerBlock = max(32, ((threadsPerBlock + 31) / 32) * 32);
|
||||
|
||||
fusedRelu2QuantizeKernel<T><<<m, threadsPerBlock, 0, stream>>>(
|
||||
input, sfScale, reinterpret_cast<uint32_t*>(outputFp4), reinterpret_cast<uint32_t*>(outputSf), m, n);
|
||||
}
|
||||
|
||||
template void invokeFusedRelu2Quantize<half>(
|
||||
half const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeFusedRelu2Quantize<__nv_bfloat16>(
|
||||
__nv_bfloat16 const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t);
|
||||
#endif
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
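The comment block on fusedRelu2QuantizeKernel above describes the intended numerics: relu2 in f32, a round-trip back to the native dtype, then a per-16-element absmax and an FP8 (e4m3) block scale. Below is a rough PyTorch sketch of that reference path, assuming a PyTorch build with float8 dtypes; it omits the e2m1 packing and the swizzled SF layout, and the helper name is made up for illustration.

import torch

def relu2_nvfp4_reference(x: torch.Tensor, sf_scale: float, sf_vec_size: int = 16):
    # x: [M, N] in fp16/bf16. relu2 in f32, then round back to the input dtype.
    y = (torch.relu(x.float()) ** 2).to(x.dtype)
    # Per-16-element block absmax, then an e4m3 scale factor (as in cvt_warp_fp16_to_fp4).
    blocks = y.view(x.shape[0], -1, sf_vec_size).float()
    vec_max = blocks.abs().amax(dim=-1)
    sf = (sf_scale * vec_max / 6.0).to(torch.float8_e4m3fn)  # stored scale factors
    out_scale = torch.where(vec_max != 0, sf_scale / sf.float(), 0.0)
    scaled = blocks * out_scale.unsqueeze(-1)  # would then be rounded to e2m1 and packed
    return scaled, sf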
cpp/tensorrt_llm/kernels/fusedActivationQuant.h (new file, 33 lines)
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/config.h"
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf,
|
||||
int m, int n, int sfVecSize, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -37,6 +37,7 @@ struct GeneralFP4AddBiasResidualPreLayerNormParam
|
||||
T const* bias = nullptr;
|
||||
T const* gamma = nullptr;
|
||||
T const* beta = nullptr;
|
||||
T* high_precision_normed_output = nullptr;
|
||||
|
||||
int m;
|
||||
int n;
|
||||
|
||||
@ -276,7 +276,7 @@ struct LowLatencyLayerNorm
|
||||
}
|
||||
|
||||
typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type normed_output;
|
||||
typename PackType<typename Traits::AccumulatorType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
|
||||
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
|
||||
high_precision_normed_output;
|
||||
for (int j = 0; j < Traits::PACKED_ELEMS_PER_COMPUTE; j++)
|
||||
{
|
||||
@ -300,7 +300,7 @@ struct LowLatencyLayerNorm
|
||||
}
|
||||
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
|
||||
{
|
||||
high_precision_normed_output.array[j] = normed_out;
|
||||
high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out;
|
||||
}
|
||||
if constexpr (Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR)
|
||||
{
|
||||
|
||||
@ -690,7 +690,7 @@ struct WarpSpecializedLayerNorm
|
||||
typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
|
||||
normed_output;
|
||||
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type output;
|
||||
typename PackType<typename Traits::AccumulatorType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
|
||||
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
|
||||
high_precision_normed_output;
|
||||
|
||||
#pragma unroll Traits::PACKED_ELEMS_PER_COMPUTE
|
||||
@ -719,6 +719,11 @@ struct WarpSpecializedLayerNorm
|
||||
normed_out += beta[j];
|
||||
}
|
||||
|
||||
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
|
||||
{
|
||||
high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out;
|
||||
}
|
||||
|
||||
if constexpr (Traits::OUTPUT_SCALE != SCALE_TYPE::NONE)
|
||||
{
|
||||
static_assert(Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR);
|
||||
@ -730,11 +735,6 @@ struct WarpSpecializedLayerNorm
|
||||
output.array[j] = (typename Traits::InputType) data[m_offset][i][j];
|
||||
}
|
||||
|
||||
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
|
||||
{
|
||||
high_precision_normed_output.array[j] = normed_out;
|
||||
}
|
||||
|
||||
normed_output.array[j] = (typename Traits::OutputType) normed_out;
|
||||
}
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ enum class SCALE_TYPE
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void invokeWSLayerNorm(WarpSpecializedParam<T> param, bool use_rms_norm, int ctas);
|
||||
void invokeWSLayerNorm(WarpSpecializedParam<T> param, bool use_rms_norm, int ctas, bool output_hp_norm = false);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
|
||||
@ -31,7 +31,8 @@ TRTLLM_NAMESPACE_BEGIN
|
||||
namespace kernels
|
||||
{
|
||||
template <typename _Param, typename _InputType, typename _OutputType, typename _AccumulatorType, bool _RMS_NORM,
|
||||
int _M_BLOCK, int _N_BLOCK, int _STAGES = 3, bool _PERSISTENT_MODE = true, bool _LOW_LATENCY_MODE = false>
|
||||
int _M_BLOCK, int _N_BLOCK, int _STAGES = 3, bool _PERSISTENT_MODE = true, bool _LOW_LATENCY_MODE = false,
|
||||
bool _HIGH_PRECISION_NORMED_OUTPUT = false>
|
||||
struct FP4AddBiasResidualPreLayerNormTraits
|
||||
{
|
||||
|
||||
@ -59,12 +60,12 @@ struct FP4AddBiasResidualPreLayerNormTraits
|
||||
static constexpr bool PERSISTENT_MODE = _PERSISTENT_MODE;
|
||||
static constexpr bool LOW_LATENCY_MODE = _LOW_LATENCY_MODE;
|
||||
static constexpr bool PREFETCH_TO_L2 = false;
|
||||
static constexpr bool HIGH_PRECISION_NORMED_OUTPUT = false;
|
||||
static constexpr bool HIGH_PRECISION_NORMED_OUTPUT = _HIGH_PRECISION_NORMED_OUTPUT;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void invokeWSLayerNormImpl(
|
||||
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<T>> param, bool use_rms_norm, int ctas)
|
||||
void invokeWSLayerNormImpl(WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<T>> param, bool use_rms_norm,
|
||||
int ctas, bool output_hp_norm)
|
||||
{
|
||||
|
||||
auto _invoke = [&](auto traits)
|
||||
@ -80,10 +81,11 @@ void invokeWSLayerNormImpl(
|
||||
{
|
||||
int waves = ((param.m + Traits::M_BLOCK - 1) / Traits::M_BLOCK + ctas - 1) / ctas;
|
||||
TLLM_LOG_DEBUG(
|
||||
"Selected TILE_M = %d, N = %d, STAGE = %d, PERSISTENT_MODE = %d, LOW_LATENCY_MODE = %d for param M = "
|
||||
"Selected TILE_M = %d, N = %d, STAGE = %d, PERSISTENT_MODE = %d, LOW_LATENCY_MODE = %d, "
|
||||
"HIGH_PRECISION_NORMED_OUTPUT = %d for param M = "
|
||||
"%d, N = %d, num_sms = %d. (waves = %d)\n",
|
||||
Traits::M_BLOCK, Traits::N_BLOCK, Traits::STAGES, Traits::PERSISTENT_MODE, Traits::LOW_LATENCY_MODE,
|
||||
param.m, param.n, ctas, waves);
|
||||
Traits::HIGH_PRECISION_NORMED_OUTPUT, param.m, param.n, ctas, waves);
|
||||
printed = true;
|
||||
}
|
||||
|
||||
@ -117,15 +119,32 @@ void invokeWSLayerNormImpl(
|
||||
constexpr auto PERSISTENT = decltype(persistent)::value;
|
||||
constexpr auto LOW_LATENCY_MODE = decltype(low_latency_mode)::value;
|
||||
|
||||
// Select kernel variant based on use_rms_norm and output_hp_norm
|
||||
if (use_rms_norm)
|
||||
{
|
||||
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
|
||||
true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE>{});
|
||||
if (output_hp_norm)
|
||||
{
|
||||
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
|
||||
true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
|
||||
true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
|
||||
false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE>{});
|
||||
if (output_hp_norm)
|
||||
{
|
||||
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
|
||||
false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
_invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
|
||||
false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -308,16 +327,18 @@ void invokeWSLayerNormImpl(
|
||||
|
||||
template <>
|
||||
void invokeWSLayerNorm<GeneralFP4AddBiasResidualPreLayerNormParam<half>>(
|
||||
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<half>> param, bool use_rms_norm, int ctas)
|
||||
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<half>> param, bool use_rms_norm, int ctas,
|
||||
bool output_hp_norm)
|
||||
{
|
||||
invokeWSLayerNormImpl(param, use_rms_norm, ctas);
|
||||
invokeWSLayerNormImpl(param, use_rms_norm, ctas, output_hp_norm);
|
||||
}
|
||||
|
||||
template <>
|
||||
void invokeWSLayerNorm<GeneralFP4AddBiasResidualPreLayerNormParam<__nv_bfloat16>>(
|
||||
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<__nv_bfloat16>> param, bool use_rms_norm, int ctas)
|
||||
WarpSpecializedParam<GeneralFP4AddBiasResidualPreLayerNormParam<__nv_bfloat16>> param, bool use_rms_norm, int ctas,
|
||||
bool output_hp_norm)
|
||||
{
|
||||
invokeWSLayerNormImpl(param, use_rms_norm, ctas);
|
||||
invokeWSLayerNormImpl(param, use_rms_norm, ctas, output_hp_norm);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -67,6 +67,7 @@ add_library(
|
||||
dsv3FusedAGemmOp.cpp
|
||||
fusedQKNormRopeOp.cpp
|
||||
fusedAddRMSNormQuant.cpp
|
||||
fusedActivationQuant.cpp
|
||||
fusedTopkSoftmax.cpp
|
||||
gatherTreeOp.cpp
|
||||
groupRmsNormOp.cpp
|
||||
|
||||
cpp/tensorrt_llm/thop/fusedActivationQuant.cpp (new file, 94 lines)
@ -0,0 +1,94 @@
|
||||
/*
|
||||
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/fusedActivationQuant.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/quantization.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/EmptyTensor.h>
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> fused_relu2_quantize(
|
||||
at::Tensor const& input, at::Tensor const& sf_scale, int64_t sf_vec_size)
|
||||
{
|
||||
CHECK_TH_CUDA(input);
|
||||
CHECK_CONTIGUOUS(input);
|
||||
CHECK_INPUT(sf_scale, torch::kFloat32);
|
||||
|
||||
auto const& inputShape = input.sizes();
|
||||
TORCH_CHECK(inputShape.size() == 2, "input should be 2D tensor [M, N].");
|
||||
|
||||
int64_t const m = inputShape[0];
|
||||
int64_t const n = inputShape[1];
|
||||
|
||||
TORCH_CHECK(sf_vec_size == 16, "sf_vec_size must be 16 for NVFP4.");
|
||||
TORCH_CHECK(n % sf_vec_size == 0, "N must be divisible by sf_vec_size.");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
at::Tensor output_fp4 = at::detail::empty_cuda({m, n / 2}, torch::kUInt8, input.device(), std::nullopt);
|
||||
int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sf_vec_size);
|
||||
at::Tensor output_sf = at::detail::empty_cuda({sfSize}, SF_DTYPE, input.device(), std::nullopt);
|
||||
|
||||
float const* sfScalePtr = sf_scale.data_ptr<float>();
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::Half)
|
||||
{
|
||||
kernels::invokeFusedRelu2Quantize<half>(reinterpret_cast<half const*>(input.data_ptr()), sfScalePtr,
|
||||
output_fp4.data_ptr<uint8_t>(), output_sf.data_ptr<uint8_t>(), static_cast<int>(m), static_cast<int>(n),
|
||||
static_cast<int>(sf_vec_size), stream);
|
||||
}
|
||||
else if (input.scalar_type() == at::ScalarType::BFloat16)
|
||||
{
|
||||
#ifdef ENABLE_BF16
|
||||
kernels::invokeFusedRelu2Quantize<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()),
|
||||
sfScalePtr, output_fp4.data_ptr<uint8_t>(), output_sf.data_ptr<uint8_t>(), static_cast<int>(m),
|
||||
static_cast<int>(n), static_cast<int>(sf_vec_size), stream);
|
||||
#else
|
||||
C10_THROW_ERROR(NotImplementedError, "BFloat16 not enabled.");
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
C10_THROW_ERROR(NotImplementedError, "fused_relu2_quantize only supports fp16/bf16.");
|
||||
}
|
||||
|
||||
return std::make_tuple(output_fp4, output_sf);
|
||||
}
|
||||
|
||||
} // namespace torch_ext
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def("fused_relu2_quantize(Tensor input, Tensor sf_scale, int sf_vec_size=16) -> (Tensor, Tensor)");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("fused_relu2_quantize", &tensorrt_llm::torch_ext::fused_relu2_quantize);
|
||||
}
|
||||
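A possible call of the op registered above through the dispatcher, following the schema fused_relu2_quantize(Tensor input, Tensor sf_scale, int sf_vec_size=16) -> (Tensor, Tensor) and the shape checks in fused_relu2_quantize (2D input, N divisible by 16); the scale value here is only a placeholder.

import torch

x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # placeholder global scale

# Returns packed FP4 values [M, N/2] (uint8) and swizzled scale factors (uint8).
out_fp4, out_sf = torch.ops.trtllm.fused_relu2_quantize(x, sf_scale, 16)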
@ -43,16 +43,18 @@ namespace torch_ext
// gamma: [N] - RMSNorm weight (fp16/bf16)
// sf_scale: [1] - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// output_hp_norm: bool - if true, also output high precision normalized values (same dtype as input) for MoE gate.
// Returns:
// normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
// output: [M, N] - pre-norm output (input + residual), same dtype as input
// sf_out: scale factors for FP4 (uint8_t), swizzled layout
// high_precision_normed_output: [M, N] - normalized output before quant (only if output_hp_norm=true, else empty)
//
// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
|
||||
at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
|
||||
double eps)
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>> fused_add_rms_norm_quant(
|
||||
at::Tensor const& input, at::Tensor const& residual, at::Tensor const& gamma,
|
||||
std::optional<at::Tensor> const& sf_scale, bool use_rms_norm, double eps, bool output_hp_norm)
|
||||
{
|
||||
CHECK_TH_CUDA(input);
|
||||
CHECK_CONTIGUOUS(input);
|
||||
@ -116,6 +118,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tens
|
||||
int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize);
|
||||
at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt);
|
||||
at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize);
|
||||
std::optional<at::Tensor> high_precision_normed_output = std::nullopt;
|
||||
if (output_hp_norm)
|
||||
{
|
||||
at::Tensor hp_normed_output_padded
|
||||
= at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
|
||||
high_precision_normed_output
|
||||
= (m_padded == m) ? hp_normed_output_padded : hp_normed_output_padded.narrow(0, 0, m);
|
||||
}
|
||||
|
||||
// Get number of SMs for persistent kernel
|
||||
static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
@ -152,12 +162,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tens
|
||||
param.bias = nullptr; \
|
||||
param.gamma = reinterpret_cast<T const*>(gamma.data_ptr()); \
|
||||
param.beta = nullptr; \
|
||||
param.high_precision_normed_output \
|
||||
= output_hp_norm ? reinterpret_cast<T*>(high_precision_normed_output.value().data_ptr()) : nullptr; \
|
||||
param.m = static_cast<int>(m); \
|
||||
param.n = static_cast<int>(n); \
|
||||
param.layernorm_eps = static_cast<float>(eps); \
|
||||
param.stream = stream; \
|
||||
param.counters = counters; \
|
||||
tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount); \
|
||||
tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount, output_hp_norm); \
|
||||
} while (0)
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::Half)
|
||||
@ -180,7 +192,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tens
|
||||
|
||||
#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT
|
||||
// No explicit sync needed - kernel runs asynchronously on the stream
|
||||
return std::make_tuple(normed_output, output, sf_out);
|
||||
return std::make_tuple(normed_output, output, sf_out, high_precision_normed_output);
|
||||
}
|
||||
|
||||
} // namespace torch_ext
|
||||
@ -191,7 +203,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def(
|
||||
"fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
|
||||
"Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6) -> (Tensor, Tensor, Tensor)");
|
||||
"Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6, bool output_hp_norm=False) -> (Tensor, Tensor, "
|
||||
"Tensor, Tensor?)");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
|
||||
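With the widened schema above, the op takes an extra output_hp_norm flag and returns an optional fourth tensor. A hedged call sketch, with shapes following the comments earlier in this file and a placeholder scale; it assumes a supported hidden size and an SM90/SM100 GPU.

import torch

hidden = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(hidden)
gamma = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # placeholder

normed_fp4, prenorm_out, sf_out, hp_normed = torch.ops.trtllm.fused_add_rms_norm_quant(
    hidden, residual, gamma, sf_scale, True, 1e-6, True)
# hp_normed is None unless output_hp_norm=True; here it is [64, 4096] in the input dtype.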
@ -1003,7 +1003,9 @@ def _register_fake():
|
||||
sf_scale: Optional[torch.Tensor],
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-5,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
output_hp_norm: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor]]:
|
||||
m, n = input.shape
|
||||
# normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
|
||||
normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
|
||||
@ -1013,4 +1015,22 @@ def _register_fake():
|
||||
sf_vec_size = 16
|
||||
sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
|
||||
sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
|
||||
return normed_output_fp4, output, sf_out
|
||||
# high_precision_normed_output: [M, N] optional, only when output_hp_norm=True
|
||||
hp_output = input.new_empty(
|
||||
(m, n), dtype=input.dtype) if output_hp_norm else None
|
||||
return normed_output_fp4, output, sf_out, hp_output
|
||||
|
||||
@torch.library.register_fake("trtllm::fused_relu2_quantize")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
sf_scale: torch.Tensor,
|
||||
sf_vec_size: int = 16,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# input: 2D tensor [M, N] (bf16 or fp16)
|
||||
# output_fp4: [M, N/2] (packed FP4 values, 2 values per byte)
|
||||
# output_sf: swizzled scale factors
|
||||
output_shape, scale_shape = fp4_utils.get_fp4_shape(
|
||||
input.shape, sf_vec_size, is_swizzled_layout=True)
|
||||
output_fp4 = input.new_empty(output_shape, dtype=torch.uint8)
|
||||
output_sf = input.new_empty((scale_shape, ), dtype=torch.uint8)
|
||||
return output_fp4, output_sf
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -23,6 +22,7 @@ from transformers import AutoConfig, PretrainedConfig
|
||||
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
|
||||
BaseWeightMapper
|
||||
from tensorrt_llm._torch.utils import ActivationType, relu2
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..distributed import AllReduce
|
||||
@ -37,7 +37,7 @@ from ..modules.mlp import MLP
|
||||
from ..modules.multi_stream_utils import maybe_execute_in_parallel
|
||||
from ..modules.rms_norm import RMSNorm
|
||||
from ..speculative import SpecMetadata
|
||||
from ..utils import AuxStreamType, EventType
|
||||
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
|
||||
from .modeling_deepseekv3 import DeepseekV3MTPHead
|
||||
from .modeling_speculative import SpecDecOneEngineForCausalLM
|
||||
from .modeling_utils import DecoderModel, register_auto_model
|
||||
@ -121,7 +121,7 @@ class NemotronHMOE(nn.Module):
|
||||
self,
|
||||
model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -242,13 +242,20 @@ class NemotronHMOE(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states: torch.Tensor
|
||||
| tuple[torch.Tensor | Fp4QuantizedTensor, torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.shape[-1] == self.hidden_dim
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_dim)
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states, hidden_states_hp = hidden_states
|
||||
else:
|
||||
hidden_states_hp = hidden_states
|
||||
|
||||
assert hidden_states_hp.shape[-1] == self.hidden_dim
|
||||
orig_shape = hidden_states_hp.shape
|
||||
hidden_states_hp_2d = hidden_states_hp.view(-1, self.hidden_dim)
|
||||
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
|
||||
|
||||
def _compute_shared_output():
|
||||
@ -259,7 +266,8 @@ class NemotronHMOE(nn.Module):
|
||||
return shared_expert_output
|
||||
|
||||
def _compute_routed_output():
|
||||
router_logits = self.gate(hidden_states)
|
||||
# Gate uses high precision input for accurate routing decisions.
|
||||
router_logits = self.gate(hidden_states_hp_2d)
|
||||
|
||||
routed_hidden_states = hidden_states
|
||||
if self.use_latent_moe:
|
||||
@ -301,7 +309,7 @@ class NemotronHLayer(DecoderLayer):
|
||||
# - -> MLPLayer
|
||||
# * -> TransformerLayer
|
||||
layer_type: str,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -310,10 +318,31 @@ class NemotronHLayer(DecoderLayer):
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = layer_type
|
||||
|
||||
self.is_nvfp4 = (model_config.quant_config is not None
|
||||
and model_config.quant_config.quant_mode is not None
|
||||
and model_config.quant_config.quant_mode.has_nvfp4())
|
||||
# The fused RMSNorm+NVFP4 CUDA kernel requires hidden_size to be
# a supported tile size. Non-power-of-2 hidden sizes within tile
# ranges may cause kernel hangs. Disable fused NVFP4 for such cases.
# Supported tile sizes: 2048, 4096, 8192, 16384
_SUPPORTED_NVFP4_HIDDEN_SIZES = {2048, 4096, 8192, 16384}
|
||||
if self.is_nvfp4 and config.hidden_size not in _SUPPORTED_NVFP4_HIDDEN_SIZES:
|
||||
logger.warning_once(
|
||||
f"Layer {layer_idx}: Disabling fused NVFP4 RMSNorm for hidden_size={config.hidden_size}. "
|
||||
f"Supported sizes: {_SUPPORTED_NVFP4_HIDDEN_SIZES}. Using non-fused path.",
|
||||
key=f"disable_nvfp4_rmsnorm_with_{config.hidden_size}")
|
||||
self.is_nvfp4 = False
|
||||
|
||||
self.norm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
# Enable fused NVFP4 quantization if possible.
|
||||
# It might be overridden in `_try_attach_nvfp4_scale` function.
|
||||
quantize_type="nvfp4" if self.is_nvfp4 else None,
|
||||
# Enable high precision output for MoE layer (only with NVFP4).
|
||||
# It might be overridden in `_try_attach_nvfp4_scale` function.
|
||||
return_hp_output=layer_type == "E" and self.is_nvfp4,
|
||||
)
|
||||
|
||||
if layer_type == "M":
|
||||
@ -343,29 +372,71 @@ class NemotronHLayer(DecoderLayer):
|
||||
else:
|
||||
raise ValueError(f"{layer_type} is not supported")
|
||||
|
||||
def post_load_weights(self):
|
||||
"""Post-process after loading weights."""
|
||||
if self.norm.is_nvfp4 and not hasattr(self.norm, 'nvfp4_scale'):
|
||||
self._try_attach_nvfp4_scale()
|
||||
|
||||
def _try_attach_nvfp4_scale(self):
|
||||
"""Attach input_scale from mixer's first linear to norm for fused RMSNorm+Quant."""
|
||||
# Normal handling for Mamba, MLP, and Attention layers.
|
||||
first_linear_attr = {
|
||||
'M': 'in_proj',
|
||||
'-': 'up_proj',
|
||||
'*': 'qkv_proj'
|
||||
}.get(self.layer_type)
|
||||
if first_linear_attr:
|
||||
first_linear = getattr(self.mixer, first_linear_attr, None)
|
||||
if first_linear and hasattr(first_linear, 'input_scale'):
|
||||
self.norm.nvfp4_scale = first_linear.input_scale
|
||||
return
|
||||
|
||||
# Special handling for MoE layer: fetch shared_expert.up_proj.input_scale
|
||||
# as representation of the input scale.
|
||||
if self.layer_type == 'E':
|
||||
if (hasattr(self.mixer, 'shared_experts')
|
||||
and self.mixer.shared_experts is not None
|
||||
and hasattr(self.mixer.shared_experts, 'up_proj')
|
||||
and hasattr(self.mixer.shared_experts.up_proj,
|
||||
'input_scale') and
|
||||
self.mixer.shared_experts.up_proj.input_scale is not None):
|
||||
self.norm.nvfp4_scale = self.mixer.shared_experts.up_proj.input_scale
|
||||
# Enable high precision output for MoE layer.
|
||||
self.norm.return_hp_output = True
|
||||
return
|
||||
|
||||
self.norm.is_nvfp4 = False
|
||||
self.norm.return_hp_output = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.IntTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
spec_metadata: SpecMetadata | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = torch.zeros_like(hidden_states)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if self.norm.return_hp_output:
|
||||
hidden_states, residual, high_precision_normed_output = self.norm(
|
||||
hidden_states, residual)
|
||||
hidden_states = (hidden_states, high_precision_normed_output)
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
hidden_states = self.mixer(hidden_states,
|
||||
attn_metadata,
|
||||
spec_metadata=spec_metadata,
|
||||
**kwargs)
|
||||
hidden_states = torch.add(hidden_states, residual)
|
||||
|
||||
if spec_metadata is not None and spec_metadata.is_layer_capture(
|
||||
self.layer_idx):
|
||||
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
|
||||
hidden_states, None)
|
||||
hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class NemotronHModel(DecoderModel):
|
||||
@ -426,10 +497,10 @@ class NemotronHModel(DecoderModel):
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_ids: Optional[torch.IntTensor] = None,
|
||||
position_ids: Optional[torch.IntTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
input_ids: torch.IntTensor | None = None,
|
||||
position_ids: torch.IntTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
spec_metadata: SpecMetadata | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
@ -443,16 +514,15 @@ class NemotronHModel(DecoderModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = torch.zeros_like(hidden_states)
|
||||
for layer in self.layers[:self.num_hidden_layers]:
|
||||
hidden_states = layer(position_ids,
|
||||
hidden_states,
|
||||
attn_metadata,
|
||||
spec_metadata=spec_metadata,
|
||||
mamba_metadata=mamba_metadata)
|
||||
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
|
||||
hidden_states, residual = layer(position_ids,
|
||||
hidden_states,
|
||||
residual=residual,
|
||||
attn_metadata=attn_metadata,
|
||||
spec_metadata=spec_metadata,
|
||||
mamba_metadata=mamba_metadata)
|
||||
hidden_states, _ = self.norm_f(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -517,7 +587,7 @@ class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel,
|
||||
self.epilogue.extend(self.draft_model.mtp_layers)
|
||||
self.epilogue.append(self.spec_worker)
|
||||
|
||||
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
|
||||
def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
|
||||
new_weights = weight_mapper.preprocess_weights(weights)
|
||||
super().load_weights(weights=new_weights, weight_mapper=weight_mapper)
|
||||
|
||||
@ -528,7 +598,7 @@ class NemotronHMTPDecoderLayer(NemotronHLayer):
|
||||
self,
|
||||
model_config: ModelConfig[NemotronHConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
|
||||
has_start_projections: bool,
|
||||
has_end_norm: bool,
|
||||
layer_type: str,
|
||||
@ -625,7 +695,7 @@ class NemotronHMTPDecoderLayer(NemotronHLayer):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
attn_metadata: AttentionMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
if self.has_start_projections:
|
||||
@ -672,7 +742,7 @@ class NemotronHMTP(nn.Module):
|
||||
def __init__(self,
|
||||
model_config: ModelConfig[NemotronHConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream],
|
||||
is_separate_draft_engine: bool = False,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -744,8 +814,8 @@ class NemotronHMTP(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
embed_tokens: Embedding,
|
||||
attn_metadata: AttentionMetadata,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
all_rank_num_tokens: list[int] | None = None,
|
||||
spec_metadata: SpecMetadata | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py (new file, 176 lines)
@ -0,0 +1,176 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fused elementwise operations for Mamba2 prefill optimization."""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _extract_transpose_prefill_kernel(
|
||||
src_ptr,
|
||||
dst_ptr,
|
||||
num_prefill_tokens,
|
||||
d_in_proj,
|
||||
d_inner,
|
||||
conv_dim,
|
||||
BLOCK_SEQ: tl.constexpr,
|
||||
BLOCK_CONV: tl.constexpr,
|
||||
):
|
||||
"""Extract src[0:num_prefill_tokens, d_inner:d_inner+conv_dim] and
transpose to dst[conv_dim, num_prefill_tokens]."""
|
||||
pid_seq = tl.program_id(0)
|
||||
pid_conv = tl.program_id(1)
|
||||
|
||||
seq_offsets = pid_seq * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
|
||||
conv_offsets = pid_conv * BLOCK_CONV + tl.arange(0, BLOCK_CONV)
|
||||
|
||||
seq_mask = seq_offsets < num_prefill_tokens
|
||||
conv_mask = conv_offsets < conv_dim
|
||||
mask = seq_mask[:, None] & conv_mask[None, :]
|
||||
|
||||
src_offsets = seq_offsets[:, None] * d_in_proj + (d_inner + conv_offsets[None, :])
|
||||
data = tl.load(src_ptr + src_offsets, mask=mask, other=0.0)
|
||||
|
||||
dst_offsets = conv_offsets[:, None] * num_prefill_tokens + seq_offsets[None, :]
|
||||
tl.store(dst_ptr + dst_offsets, tl.trans(data), mask=conv_mask[:, None] & seq_mask[None, :])
|
||||
|
||||
|
||||
def extract_transpose_xbc_prefill(
|
||||
zxbcdt: torch.Tensor,
|
||||
num_prefill_tokens: int,
|
||||
d_inner: int,
|
||||
conv_dim: int,
|
||||
) -> torch.Tensor:
|
||||
"""
Extract and transpose xbc slice from zxbcdt for causal_conv1d_fn.

Input: zxbcdt[num_tokens, d_in_proj]
Output: [conv_dim, num_prefill_tokens]
"""
out = torch.empty(conv_dim, num_prefill_tokens, dtype=zxbcdt.dtype, device=zxbcdt.device)
|
||||
|
||||
BLOCK_SEQ, BLOCK_CONV = 32, 128
|
||||
grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(conv_dim, BLOCK_CONV))
|
||||
|
||||
_extract_transpose_prefill_kernel[grid](
|
||||
zxbcdt,
|
||||
out,
|
||||
num_prefill_tokens,
|
||||
zxbcdt.shape[1],
|
||||
d_inner,
|
||||
conv_dim,
|
||||
BLOCK_SEQ,
|
||||
BLOCK_CONV,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
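For reference, the unfused PyTorch equivalent of extract_transpose_xbc_prefill is just a slice, transpose, and contiguous copy; the Triton kernel above exists to avoid materializing that intermediate.

def extract_transpose_xbc_prefill_ref(zxbcdt, num_prefill_tokens, d_inner, conv_dim):
    # Same result as the Triton kernel: [conv_dim, num_prefill_tokens], contiguous.
    xbc = zxbcdt[:num_prefill_tokens, d_inner:d_inner + conv_dim]
    return xbc.transpose(0, 1).contiguous()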
@triton.jit
|
||||
def _fused_conv_output_transpose_kernel(
|
||||
src_ptr,
|
||||
out_x_ptr,
|
||||
out_B_ptr,
|
||||
out_C_ptr,
|
||||
num_prefill_tokens,
|
||||
d_inner,
|
||||
bc_size,
|
||||
x_tiles,
|
||||
bc_tiles,
|
||||
BLOCK_SEQ: tl.constexpr,
|
||||
BLOCK_DIM: tl.constexpr,
|
||||
):
|
||||
"""
Transpose and split conv1d output into x, B, C using linear grid mapping.

Grid: tiles [0, x_tiles) -> x, [x_tiles, x_tiles+bc_tiles) -> B, rest -> C
"""
tile_id = tl.program_id(0)
|
||||
|
||||
is_x = tile_id < x_tiles
|
||||
is_B = (tile_id >= x_tiles) & (tile_id < x_tiles + bc_tiles)
|
||||
|
||||
local_tile = tl.where(
|
||||
is_x, tile_id, tl.where(is_B, tile_id - x_tiles, tile_id - x_tiles - bc_tiles)
|
||||
)
|
||||
dim_size = tl.where(is_x, d_inner, bc_size)
|
||||
num_dim_blocks = tl.cdiv(dim_size, BLOCK_DIM)
|
||||
|
||||
pid_seq = local_tile // num_dim_blocks
|
||||
pid_dim = local_tile % num_dim_blocks
|
||||
|
||||
seq_offsets = pid_seq * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
|
||||
dim_offsets = pid_dim * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||
|
||||
seq_mask = seq_offsets < num_prefill_tokens
|
||||
dim_mask = dim_offsets < dim_size
|
||||
|
||||
src_offset = tl.where(is_x, 0, tl.where(is_B, d_inner, d_inner + bc_size))
|
||||
src_indices = (src_offset + dim_offsets[:, None]) * num_prefill_tokens + seq_offsets[None, :]
|
||||
data = tl.load(src_ptr + src_indices, mask=dim_mask[:, None] & seq_mask[None, :], other=0.0)
|
||||
|
||||
out_ptr = tl.where(is_x, out_x_ptr, tl.where(is_B, out_B_ptr, out_C_ptr))
|
||||
dst_indices = seq_offsets[:, None] * dim_size + dim_offsets[None, :]
|
||||
tl.store(out_ptr + dst_indices, tl.trans(data), mask=seq_mask[:, None] & dim_mask[None, :])
|
||||
|
||||
|
||||
def fused_split_rearrange_after_conv1d(
|
||||
xbc: torch.Tensor,
|
||||
d_inner: int,
|
||||
n_groups: int,
|
||||
d_state: int,
|
||||
nheads: int,
|
||||
head_dim: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
Split and rearrange causal_conv1d output into contiguous x, B, C tensors.

Input: xbc[conv_dim, num_prefill_tokens]
Output: x[1, num_prefill_tokens, nheads, head_dim],
B[1, num_prefill_tokens, n_groups, d_state],
C[1, num_prefill_tokens, n_groups, d_state]
"""
conv_dim, num_prefill_tokens = xbc.shape
|
||||
bc_size = n_groups * d_state
|
||||
|
||||
x_flat = torch.empty(num_prefill_tokens, d_inner, dtype=xbc.dtype, device=xbc.device)
|
||||
B_flat = torch.empty(num_prefill_tokens, bc_size, dtype=xbc.dtype, device=xbc.device)
|
||||
C_flat = torch.empty(num_prefill_tokens, bc_size, dtype=xbc.dtype, device=xbc.device)
|
||||
|
||||
BLOCK_SEQ, BLOCK_DIM = 64, 64
|
||||
num_seq_blocks = triton.cdiv(num_prefill_tokens, BLOCK_SEQ)
|
||||
x_tiles = num_seq_blocks * triton.cdiv(d_inner, BLOCK_DIM)
|
||||
bc_tiles = num_seq_blocks * triton.cdiv(bc_size, BLOCK_DIM)
|
||||
|
||||
_fused_conv_output_transpose_kernel[(x_tiles + 2 * bc_tiles,)](
|
||||
xbc,
|
||||
x_flat,
|
||||
B_flat,
|
||||
C_flat,
|
||||
num_prefill_tokens,
|
||||
d_inner,
|
||||
bc_size,
|
||||
x_tiles,
|
||||
bc_tiles,
|
||||
BLOCK_SEQ,
|
||||
BLOCK_DIM,
|
||||
)
|
||||
|
||||
return (
|
||||
x_flat.view(1, num_prefill_tokens, nheads, head_dim),
|
||||
B_flat.view(1, num_prefill_tokens, n_groups, d_state),
|
||||
C_flat.view(1, num_prefill_tokens, n_groups, d_state),
|
||||
)
|
||||
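Likewise, an unfused PyTorch sketch of what fused_split_rearrange_after_conv1d produces, usable as a correctness reference; shapes follow the docstring above.

import torch

def fused_split_rearrange_after_conv1d_ref(xbc, d_inner, n_groups, d_state, nheads, head_dim):
    # xbc: [conv_dim, num_prefill_tokens] -> x, B, C in [1, T, ...] layouts.
    t = xbc.shape[1]
    bc = n_groups * d_state
    x, B, C = torch.split(xbc, [d_inner, bc, bc], dim=0)
    x = x.transpose(0, 1).contiguous().view(1, t, nheads, head_dim)
    B = B.transpose(0, 1).contiguous().view(1, t, n_groups, d_state)
    C = C.transpose(0, 1).contiguous().view(1, t, n_groups, d_state)
    return x, B, C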
@ -17,6 +17,8 @@ import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
@ -25,6 +27,86 @@ from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
|
||||
use_cpp_mamba_cache_manager
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _cu_seqlens_triton_kernel(
|
||||
cu_seqlens_ptr, # [num_seqs + 1]
|
||||
chunk_indices_ptr, # [N] output
|
||||
chunk_offsets_ptr, # [N] output
|
||||
num_seqs: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""Computes chunk_indices and chunk_offsets in a single kernel launch."""
|
||||
pid = tl.program_id(0)
|
||||
chunk_start = pid * BLOCK_SIZE
|
||||
offsets = chunk_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < N
|
||||
chunk_indices = offsets.to(tl.int64)
|
||||
chunk_offsets = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
|
||||
|
||||
p = 0
|
||||
for seq_idx in range(num_seqs - 1):
|
||||
seq_start = tl.load(cu_seqlens_ptr + seq_idx + 1).to(tl.int64)
|
||||
seq_end = tl.load(cu_seqlens_ptr + seq_idx + 2).to(tl.int64)
|
||||
is_misaligned = (seq_start % chunk_size) > 0
|
||||
p = p + is_misaligned
|
||||
s_chunk = seq_start // chunk_size + p
|
||||
e_chunk = seq_end // chunk_size + p + ((seq_end % chunk_size) > 0)
|
||||
in_range = (offsets >= s_chunk) & (offsets < e_chunk)
|
||||
chunk_indices = tl.where(in_range & mask, chunk_indices - p,
|
||||
chunk_indices)
|
||||
is_start = (offsets == s_chunk)
|
||||
chunk_offsets = tl.where(is_start & mask, seq_start % chunk_size,
|
||||
chunk_offsets)
|
||||
|
||||
tl.store(chunk_indices_ptr + offsets, chunk_indices.to(tl.int32), mask=mask)
|
||||
tl.store(chunk_offsets_ptr + offsets, chunk_offsets.to(tl.int32), mask=mask)
|
||||
|
||||
|
||||
def cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens: torch.Tensor,
|
||||
chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Optimized version of cu_seqlens_to_chunk_indices_offsets."""
|
||||
device = cu_seqlens.device
|
||||
num_seqs = cu_seqlens.numel() - 1
|
||||
|
||||
if num_seqs == 0:
|
||||
return (torch.empty(0, dtype=torch.int, device=device),
|
||||
torch.empty(0, dtype=torch.int, device=device))
|
||||
|
||||
cu = cu_seqlens.to(dtype=torch.int64)
|
||||
total_seqlens = cu[-1].item()
|
||||
|
||||
if num_seqs == 1:
|
||||
# Fast path for single sequence (no boundaries to process)
|
||||
N = (total_seqlens + chunk_size - 1) // chunk_size
|
||||
return (torch.arange(N, device=device, dtype=torch.int),
|
||||
torch.zeros(N, device=device, dtype=torch.int))
|
||||
|
||||
seq_starts = cu[1:-1]
|
||||
misaligned = ((seq_starts % chunk_size) > 0).to(torch.int64)
|
||||
p = torch.cumsum(misaligned, dim=0)
|
||||
extra_chunks = p[-1].item() if p.numel() > 0 else 0
|
||||
N = (total_seqlens + chunk_size - 1) // chunk_size + extra_chunks
|
||||
chunk_indices = torch.empty(N, device=device, dtype=torch.int)
|
||||
chunk_offsets = torch.empty(N, device=device, dtype=torch.int)
|
||||
|
||||
BLOCK_SIZE = 256
|
||||
grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
|
||||
_cu_seqlens_triton_kernel[grid](
|
||||
cu,
|
||||
chunk_indices,
|
||||
chunk_offsets,
|
||||
num_seqs=num_seqs,
|
||||
chunk_size=chunk_size,
|
||||
N=N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return chunk_indices, chunk_offsets
|
||||
|
||||
|
||||
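Since the Triton helper is documented as an optimized version of cu_seqlens_to_chunk_indices_offsets (defined below), the simplest sanity check is a parity comparison; the sequence lengths here are arbitrary, and this assumes both helpers return tensors of matching dtype and device.

import torch

cu_seqlens = torch.tensor([0, 100, 356, 1000], dtype=torch.int32, device="cuda")
idx_ref, off_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size=256)
idx_opt, off_opt = cu_seqlens_to_chunk_indices_offsets_triton(cu_seqlens, chunk_size=256)
assert torch.equal(idx_ref, idx_opt) and torch.equal(off_ref, off_opt)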
def cu_seqlens_to_chunk_indices_offsets(
|
||||
cu_seqlens: torch.Tensor,
|
||||
chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -117,15 +199,15 @@ class Mamba2Metadata:
|
||||
self.state_indices = torch.zeros(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
self._query_start_loc_long_buf = torch.arange(0,
|
||||
max_batch_size + 1,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self._query_start_loc_buf = torch.zeros(max_batch_size + 1,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
self.query_start_loc_long = self._query_start_loc_long_buf
|
||||
self.query_start_loc = self._query_start_loc_buf
|
||||
|
||||
# Pre-allocated buffers.
|
||||
self._arange_buffer = torch.arange(max_batch_size + 1,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
self._arange_buffer_long = self._arange_buffer.to(torch.long)
|
||||
self._cu_seqlens_long = torch.zeros(max_batch_size + 1,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
|
||||
def prepare(self, attn_metadata: AttentionMetadata):
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
@ -158,47 +240,32 @@ class Mamba2Metadata:
|
||||
dtype=torch.int,
|
||||
out=self.cu_seqlens[1:num_contexts + 1])
|
||||
torch.add(self.cu_seqlens[num_contexts],
|
||||
torch.arange(1,
|
||||
batch_size - num_contexts + 1,
|
||||
dtype=self.cu_seqlens.dtype,
|
||||
device=self.cu_seqlens.device),
|
||||
self._arange_buffer[1:batch_size - num_contexts + 1],
|
||||
out=self.cu_seqlens[num_contexts + 1:batch_size + 1])
|
||||
# Need both `query_start_loc` and `query_start_loc_long` because `causal_conv1d_fn`
|
||||
# accepts only `int32` while `chunk_gated_delta_rule` accepts only `long`.
|
||||
self._query_start_loc_buf[:batch_size +
|
||||
1] = self.cu_seqlens[:batch_size + 1]
|
||||
self.query_start_loc = self._query_start_loc_buf[:batch_size + 1]
|
||||
self._query_start_loc_long_buf[:batch_size + 1].copy_(
|
||||
self.query_start_loc.to(torch.long), non_blocking=True)
|
||||
self.query_start_loc_long = self._query_start_loc_long_buf[:
|
||||
batch_size
|
||||
+ 1]
|
||||
self.query_start_loc = self.cu_seqlens[:batch_size + 1]
|
||||
self._cu_seqlens_long[:batch_size + 1].copy_(self.query_start_loc)
|
||||
self.query_start_loc_long = self._cu_seqlens_long[:batch_size + 1]
|
||||
self.seq_idx = torch.repeat_interleave(
|
||||
torch.arange(num_contexts,
|
||||
dtype=torch.int,
|
||||
device=self.cu_seqlens.device),
|
||||
self._arange_buffer[:num_contexts],
|
||||
repeats=context_lens,
|
||||
output_size=num_ctx_tokens).unsqueeze(0)
|
||||
|
||||
num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq
|
||||
self.has_initial_states[:num_contexts] = torch.tensor(
|
||||
num_cached_tokens_per_seq[:num_contexts]) > 0
|
||||
# precomputed bool to avoid host<->device syncs during forward pass
|
||||
self.use_initial_states = torch.any(
|
||||
self.has_initial_states[:num_contexts]).item()
|
||||
initial_states = [
|
||||
num_cached_tokens_per_seq[i] > 0 for i in range(num_contexts)
|
||||
]
|
||||
self.use_initial_states = any(initial_states)
|
||||
if self.use_initial_states:
|
||||
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
|
||||
self.has_initial_states[:num_contexts] = torch.tensor(
|
||||
initial_states, dtype=torch.bool)
|
||||
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
self.cu_seqlens[:num_contexts + 1], self.chunk_size)
|
||||
else:
|
||||
self.chunk_indices = None
|
||||
self.chunk_offsets = None
|
||||
else:
|
||||
self.query_start_loc = None
|
||||
torch.arange(0,
|
||||
batch_size + 1,
|
||||
dtype=torch.long,
|
||||
device=self.cu_seqlens.device,
|
||||
out=self._query_start_loc_long_buf[:batch_size + 1])
|
||||
self.query_start_loc_long = self._query_start_loc_long_buf[:
|
||||
batch_size
|
||||
+ 1]
|
||||
self.query_start_loc_long = self._arange_buffer_long[:batch_size +
|
||||
1]
|
||||
|
||||
@ -33,6 +33,8 @@ from ..linear import Linear, TensorParallelMode
from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .causal_conv1d_triton import \
causal_conv1d_update as causal_conv1d_update_triton
from .fuse_elementwise_ops import (extract_transpose_xbc_prefill,
fused_split_rearrange_after_conv1d)
from .layernorm_gated import RMSNorm as RMSNormGated
from .selective_state_update import \
selective_state_update as selective_state_update_native
@ -227,15 +229,17 @@ class Mamba2Mixer(nn.Module):

# in_proj
zxbcdt = self.in_proj(hidden_states)
z, xbc, dt = torch.split(
zxbcdt,
[self.tp_d_inner, self.tp_conv_dim, self.tp_nheads],
dim=-1,
)

# Split z and dt with views.
z = zxbcdt[:, :self.tp_d_inner]
dt = zxbcdt[:, self.tp_d_inner + self.tp_conv_dim:]
z_p, z_d = torch.split(z, seqlen_split_size, dim=0)
xbc_p, xbc_d = torch.split(xbc, seqlen_split_size, dim=0)
dt_p, dt_d = torch.split(dt, seqlen_split_size, dim=0)

# Decode path uses regular view since no transpose is needed.
xbc_d = zxbcdt[num_prefill_tokens:num_actual_tokens,
self.tp_d_inner:self.tp_d_inner + self.tp_conv_dim]

# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
@ -243,8 +247,8 @@ class Mamba2Mixer(nn.Module):
zxbcdt.shape[0],
(self.num_heads * self.head_dim) // self.tp_size,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
dtype=zxbcdt.dtype,
device=zxbcdt.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
@ -259,27 +263,29 @@ class Mamba2Mixer(nn.Module):
has_initial_states = mamba_metadata.has_initial_states[:
num_prefills]

xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1),
# Fused kernel to avoid expensive .contiguous() call in causal_conv1d_fn.
xbc_p_t = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens,
self.tp_d_inner,
self.tp_conv_dim)
xbc_p = causal_conv1d_fn(xbc_p_t,
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_states=conv_states,
has_initial_state=has_initial_states,
query_start_loc=cu_seqlens,
cache_indices=state_indices_p).transpose(
0, 1)
cache_indices=state_indices_p)

x_p, B_p, C_p = torch.split(xbc_p.unsqueeze(0), [
# Fused kernel to avoid expensive .contiguous() calls after split/rearrange.
x_p, B_p, C_p = fused_split_rearrange_after_conv1d(
xbc_p,
self.tp_d_inner,
self.tp_ngroups * self.d_state,
self.tp_ngroups * self.d_state,
],
dim=-1)

x_p = rearrange(x_p, "b l (h p) -> b l h p", h=self.tp_nheads)
self.tp_ngroups,
self.d_state,
self.tp_nheads,
self.head_dim,
)
dt_p = dt_p.unsqueeze(0)
B_p = rearrange(B_p, "b l (g n) -> b l g n", g=self.tp_ngroups)
C_p = rearrange(C_p, "b l (g n) -> b l g n", g=self.tp_ngroups)
z_p = rearrange(z_p.unsqueeze(0),
"b l (h p) -> b l h p",
h=self.tp_nheads)
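
For reference, the fused Triton helper above replaces an explicit slice + transpose + `.contiguous()`. A rough unfused equivalent, mirroring the reference implementation used by the new unit test (dimension names are illustrative):

import torch

def extract_transpose_xbc_prefill_unfused(zxbcdt: torch.Tensor,
                                          num_prefill_tokens: int,
                                          d_inner: int,
                                          conv_dim: int) -> torch.Tensor:
    """Unfused sketch: slice the xBC block of the in_proj output for the
    prefill tokens and return it channels-first, as causal_conv1d_fn expects."""
    xbc = zxbcdt[:num_prefill_tokens, d_inner:d_inner + conv_dim]
    # This .contiguous() is the extra memcpy the fused kernel avoids.
    return xbc.transpose(0, 1).contiguous()
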
@ -26,6 +26,124 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# =================================================================
|
||||
# Higher warp count configs for better latency hiding
|
||||
# More warps = more instructions in flight = better memory latency hiding
|
||||
# =================================================================
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8, # 8 warps = 256 threads per block
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8, # 8 warps for better latency hiding
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
# Smaller tiles with more stages for software pipelining
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
# =================================================================
|
||||
# Low register pressure configs (num_stages=1) for large dstate
|
||||
# =================================================================
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
# num_stages=2 configs - moderate register pressure
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
# Original configs for smaller dstate values
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -355,14 +473,17 @@ def _chunk_scan_fwd_kernel(

if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m),
# Use exp2 for faster computation: exp(x) = exp2(x * log2(e))
scale_m = tl.where(seq_idx_m == seq_idx_prev,
tl.math.exp2(dA_cs_m * 1.4426950408889634),
0.0)
else:
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
scale_m = tl.math.exp2(
(dA_cs_m - dA_cs_m_boundary) * 1.4426950408889634)
else:
scale_m = tl.exp(dA_cs_m)
scale_m = tl.math.exp2(dA_cs_m * 1.4426950408889634)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(
C_ptrs,
@ -421,7 +542,9 @@ def _chunk_scan_fwd_kernel(
other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
# Use exp2 for faster computation: exp(x) = exp2(x * log2(e))
cb *= tl.math.exp2(
(dA_cs_m[:, None] - dA_cs_k[None, :]) * 1.4426950408889634)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
cb *= dt_k
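
The exp-to-exp2 rewrite above relies on the identity exp(x) = 2^(x * log2(e)); the constant 1.4426950408889634 is log2(e). A quick numerical sanity check of the identity in plain PyTorch, outside the kernel:

import torch

LOG2_E = 1.4426950408889634  # log2(e), same constant as used in the kernel
x = torch.randn(1024, dtype=torch.float32)
# exp2 maps to the faster hardware path; results agree with exp to float32 rounding.
assert torch.allclose(torch.exp(x), torch.exp2(x * LOG2_E), rtol=1e-5, atol=1e-6)
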
@ -128,6 +128,54 @@ def _chunk_cumsum_fwd_kernel(
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# Small headdim/dstate configs (hdim<=64, dstate<=128) - increased parallelism
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
# Low register pressure configs for large dstate (dstate=128)
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
# Original configs for larger headdim/dstate values
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -175,40 +223,13 @@ def _chunk_cumsum_fwd_kernel(
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["hdim", "dstate", "chunk_size"],
|
||||
)
|
||||
@ -351,6 +372,41 @@ def _chunk_state_fwd_kernel(
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# Small headdim/dstate configs for better parallelism
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=4),
|
||||
# Low register pressure configs
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4),
|
||||
# Original configs for larger dimensions
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -393,36 +449,12 @@ def _chunk_state_fwd_kernel(
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=["hdim", "dstate", "chunk_size"],
|
||||
)
|
||||
|
||||
@ -8,23 +8,25 @@ from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor, relu2
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig


class MLP(nn.Module):

def __init__(self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
layer_idx: Optional[int] = None,
reduce_output: bool = True,
overridden_tp_size: Optional[int] = None):

def __init__(
self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
layer_idx: Optional[int] = None,
reduce_output: bool = True,
overridden_tp_size: Optional[int] = None,
):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
@ -81,7 +83,22 @@ class MLP(nn.Module):
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
reduce_output=reduce_output)
reduce_output=reduce_output,
)

self._use_fused_relu2_quant = False

def create_weights(self):
self.up_proj.create_weights()
self.down_proj.create_weights()

has_nvfp4 = hasattr(self.down_proj,
'has_nvfp4') and self.down_proj.has_nvfp4
has_kernel = hasattr(torch.ops.trtllm, 'fused_relu2_quantize')
has_scale = hasattr(self.down_proj, 'input_scale')
is_relu2 = self.activation is relu2

self._use_fused_relu2_quant = has_nvfp4 and has_kernel and has_scale and is_relu2

def forward(
self,
@ -92,11 +109,34 @@ class MLP(nn.Module):
return self.forward_lora(x, lora_params=lora_params)

x_up = self.up_proj(x)
x_act = self.activation(x_up)

if self._use_fused_relu2_quant:
x_act = self._fused_relu2_quant(x_up)
else:
x_act = self.activation(x_up)

x_down = self.down_proj(x_act)

return x_down

def _fused_relu2_quant(self, x: torch.Tensor) -> Fp4QuantizedTensor:
x_flat = x.view(-1, x.shape[-1])

if not x_flat.is_contiguous():
x_flat = x_flat.contiguous()

if x_flat.dtype not in (torch.float16, torch.bfloat16):
x_flat = x_flat.to(torch.bfloat16)

fp4_tensor, sf_tensor = torch.ops.trtllm.fused_relu2_quantize(
x_flat, self.down_proj.input_scale, 16)

return Fp4QuantizedTensor(
fp4_tensor=fp4_tensor,
scaling_factor=sf_tensor,
is_sf_swizzled=True,
)

def forward_lora(
self,
x: torch.Tensor,
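
The fused path above collapses relu2 and NVFP4 quantization into a single kernel. Conceptually it is equivalent to running the activation and `fp4_quantize` separately, which is exactly how the new unit test checks it. A hedged sketch of that unfused baseline, assuming the `trtllm.fp4_quantize` op and an `input_scale` computed as amax / (6 * 448):

import torch
import torch.nn.functional as F

def relu2_then_quantize_unfused(x: torch.Tensor, input_scale: torch.Tensor):
    """Unfused baseline: square(relu(x)) followed by NVFP4 block quantization
    with 16-element scale blocks (sf_vec_size=16)."""
    x_act = torch.square(F.relu(x))
    fp4_packed, sf = torch.ops.trtllm.fp4_quantize(
        x_act, input_scale, 16, False, True)  # use_ue8m0=False, swizzled SF layout
    return fp4_packed, sf
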
@ -41,6 +41,7 @@ class RMSNorm(nn.Module):
use_gemma: bool = False,
quantize_type: Optional[str] = None,
use_cuda_tile: bool = False,
return_hp_output: bool = False,
):
super().__init__()

@ -72,6 +73,7 @@ class RMSNorm(nn.Module):
self.variance_epsilon = eps
self.use_gemma = use_gemma
self.use_cuda_tile = use_cuda_tile
self.return_hp_output = return_hp_output

def forward(
self,
@ -80,7 +82,8 @@ class RMSNorm(nn.Module):
Optional[torch.Tensor],
_ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
) -> Union[torch.Tensor, Fp4QuantizedTensor, Tuple[Union[
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]]]:
torch.Tensor, Fp4QuantizedTensor], Optional[torch.Tensor]], Tuple[
Fp4QuantizedTensor, torch.Tensor, torch.Tensor]]:
has_residual = residual is not self._ARGUMENT_NOT_SPECIFIED_SENTINEL
if not has_residual:
residual = None
@ -116,14 +119,16 @@ class RMSNorm(nn.Module):

sf_scale = nvfp4_scale.contiguous()

normed_fp4_i32, residual_out_2d, sf_fused = torch.ops.trtllm.fused_add_rms_norm_quant(
results = torch.ops.trtllm.fused_add_rms_norm_quant(
hs_2d,
res_2d,
gamma,
sf_scale,
True,
eps=self.variance_epsilon,
output_hp_norm=self.return_hp_output,
)
normed_fp4_i32, residual_out_2d, sf_fused = results[:3]
normed_fp4_u8 = normed_fp4_i32.view(torch.uint8)
if len(orig_shape) != 2:
normed_fp4_u8 = normed_fp4_u8.reshape(*orig_shape[:-1], n // 2)
@ -132,9 +137,21 @@ class RMSNorm(nn.Module):
residual_out = residual_out_2d

hidden_states_fused = Fp4QuantizedTensor(normed_fp4_u8, sf_fused)
return (hidden_states_fused,
residual_out) if has_residual else hidden_states_fused
elif self.use_cuda_tile:

outputs = [hidden_states_fused]
if has_residual:
outputs.append(residual_out)
if self.return_hp_output:
high_precision_normed_output = results[3].reshape(orig_shape)
outputs.append(high_precision_normed_output)
return outputs[0] if len(outputs) == 1 else tuple(outputs)

if self.return_hp_output:
raise ValueError(
"Auxiliary high precision output is only supported for NVFP4 fused path"
)

if self.use_cuda_tile:
if isinstance(residual, torch.Tensor):
# Use fused residual kernel
hidden_states = hidden_states.contiguous()
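
Because `forward` now returns between one and three values depending on whether a residual was passed and on `return_hp_output`, callers must unpack accordingly. A small illustrative call site on the NVFP4 fused path (variable names hypothetical):

# Hypothetical caller: with a residual and return_hp_output=True, the fused
# NVFP4 path returns (quantized hidden states, residual, high-precision norm).
out = norm(hidden_states, residual)
if norm.return_hp_output:
    hidden_states_fp4, residual, hp_normed = out
else:
    hidden_states_fp4, residual = out
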
@ -593,7 +593,9 @@ class MambaCacheManager(BaseResourceManager):
return self._impl.get_intermediate_conv_states(layer_idx)

def is_speculative(self) -> bool:
assert not self._use_cpp, "is_speculative is not supported in CppMambaCacheManager"
if self._use_cpp:
# CppMambaCacheManager does not support speculative decoding for now.
return False
return self._impl.is_speculative()

def mamba_layer_cache(

@ -1,5 +1,6 @@
import contextlib
import functools
import inspect
import itertools
import os
import unittest.mock
@ -443,9 +444,9 @@ class Runner:

def forward(position_ids, hidden_states, attn_metadata, residual, **kwargs):
# TODO: to be more general, we should call DecoderModel.forward
residual_fusion = hasattr(model.model.layers[layer_indices[0]], "next_layer_layernorm")
for layer_idx in layer_indices:
layer = model.model.layers[layer_idx]
residual_fusion = "residual" in inspect.signature(layer.forward).parameters
if residual_fusion:
hidden_states, residual = layer(
position_ids, hidden_states, attn_metadata, residual, **kwargs
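
The runner now probes each layer's own signature instead of assuming every layer in the block accepts a residual stream. A standalone sketch of that per-layer check (class names here are illustrative):

import inspect

class LayerWithResidual:
    def forward(self, position_ids, hidden_states, attn_metadata, residual):
        return hidden_states, residual

class LayerWithoutResidual:
    def forward(self, position_ids, hidden_states, attn_metadata):
        return hidden_states

# Only layers whose forward() declares a `residual` parameter are called with,
# and expected to return, the extra residual tensor.
for layer in (LayerWithResidual(), LayerWithoutResidual()):
    residual_fusion = "residual" in inspect.signature(layer.forward).parameters
    print(type(layer).__name__, residual_fusion)  # True, then False
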
246  tests/unittest/_torch/modules/mamba/test_causal_conv1d.py  Normal file
@ -0,0 +1,246 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
|
||||
|
||||
|
||||
def mamba_conv1d_ref(x, past_conv_state, conv_weight, conv_bias, apply_silu):
|
||||
"""
|
||||
Reference implementation for causal conv1d.
|
||||
|
||||
Arguments:
|
||||
x: [batch_size, dim, seq_len]
|
||||
past_conv_state: [batch_size, dim, dconv-1]
|
||||
conv_weight: [dim, 1, dconv]
|
||||
conv_bias: [dim]
|
||||
Output:
|
||||
y: [batch_size, dim, seq_len]
|
||||
present_conv_state: [batch_size, dim, dconv-1]
|
||||
"""
|
||||
assert x.dim() == 3
|
||||
assert past_conv_state.dim() == 3
|
||||
assert conv_weight.dim() == 3
|
||||
assert conv_bias.dim() == 1
|
||||
batch_size, dim, seq_len = x.shape
|
||||
assert past_conv_state.shape[0] == batch_size
|
||||
assert past_conv_state.shape[1] == dim
|
||||
dconv = past_conv_state.shape[2] + 1
|
||||
assert conv_weight.shape[0] == dim
|
||||
assert conv_weight.shape[1] == 1
|
||||
assert conv_weight.shape[2] == dconv
|
||||
|
||||
padded_x = torch.cat([past_conv_state, x], dim=2)
|
||||
present_conv_state = padded_x[:, :, -(dconv - 1) :]
|
||||
x_conv = F.conv1d(padded_x, conv_weight, bias=conv_bias, groups=dim)
|
||||
|
||||
y = F.silu(x_conv) if apply_silu else x_conv
|
||||
return y, present_conv_state
|
||||
|
||||
|
||||
def trtllm_causal_conv1d_available():
|
||||
"""Check if trtllm.causal_conv1d_fwd is available."""
|
||||
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "causal_conv1d_fwd")
|
||||
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or not trtllm_causal_conv1d_available(),
|
||||
reason="Requires CUDA and trtllm.causal_conv1d_fwd op",
|
||||
)
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
class TestCausalConv1d:
|
||||
"""Tests for the causal_conv1d CUDA kernel."""
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["float16", "bfloat16", "float32"])
|
||||
@pytest.mark.parametrize("apply_silu", [True, False])
|
||||
@pytest.mark.parametrize("dim", [256, 512, 1024, 2048])
|
||||
def test_basic_correctness(self, dtype, apply_silu, dim):
|
||||
"""Test basic correctness against reference implementation."""
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
|
||||
batch_size = 4
|
||||
seq_len = 32
|
||||
dconv = 4
|
||||
std_dev = 0.5
|
||||
x = torch.randn(batch_size, dim, seq_len, dtype=torch_dtype, device=device)
|
||||
x = x * std_dev
|
||||
conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=torch_dtype, device=device)
|
||||
conv_weight = torch.randn(dim, 1, dconv, dtype=torch_dtype, device=device)
|
||||
conv_bias = torch.randn(dim, dtype=torch_dtype, device=device)
|
||||
x_kernel = x.clone()
|
||||
conv_state_kernel = conv_state.clone()
|
||||
|
||||
conv_weight_input = conv_weight.squeeze(1).contiguous()
|
||||
torch.ops.trtllm.causal_conv1d_fwd(
|
||||
x_kernel,
|
||||
conv_weight_input,
|
||||
conv_bias,
|
||||
conv_state_kernel,
|
||||
None, # query_start_loc
|
||||
None, # cache_indices
|
||||
None, # has_initial_state
|
||||
apply_silu,
|
||||
PAD_SLOT_ID,
|
||||
)
|
||||
out_ref, conv_state_ref = mamba_conv1d_ref(
|
||||
x, conv_state, conv_weight, conv_bias, apply_silu
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-2)
|
||||
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
|
||||
def test_various_batch_sizes(self, batch_size):
|
||||
"""Test with various batch sizes."""
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
dim = 1024
|
||||
seq_len = 64
|
||||
dconv = 4
|
||||
apply_silu = True
|
||||
|
||||
x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
|
||||
conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
|
||||
conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
|
||||
conv_bias = torch.randn(dim, dtype=dtype, device=device)
|
||||
x_kernel = x.clone()
|
||||
conv_state_kernel = conv_state.clone()
|
||||
|
||||
conv_weight_input = conv_weight.squeeze(1).contiguous()
|
||||
torch.ops.trtllm.causal_conv1d_fwd(
|
||||
x_kernel,
|
||||
conv_weight_input,
|
||||
conv_bias,
|
||||
conv_state_kernel,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
apply_silu,
|
||||
PAD_SLOT_ID,
|
||||
)
|
||||
out_ref, conv_state_ref = mamba_conv1d_ref(
|
||||
x, conv_state, conv_weight, conv_bias, apply_silu
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1)
|
||||
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
|
||||
|
||||
@pytest.mark.parametrize("dconv", [2, 3, 4])
|
||||
def test_various_kernel_widths(self, dconv):
|
||||
"""Test with different convolution kernel widths."""
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_size = 4
|
||||
dim = 1024
|
||||
seq_len = 64
|
||||
apply_silu = True
|
||||
x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
|
||||
conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
|
||||
conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
|
||||
conv_bias = torch.randn(dim, dtype=dtype, device=device)
|
||||
x_kernel = x.clone()
|
||||
conv_state_kernel = conv_state.clone()
|
||||
|
||||
conv_weight_input = conv_weight.squeeze(1).contiguous()
|
||||
torch.ops.trtllm.causal_conv1d_fwd(
|
||||
x_kernel,
|
||||
conv_weight_input,
|
||||
conv_bias,
|
||||
conv_state_kernel,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
apply_silu,
|
||||
PAD_SLOT_ID,
|
||||
)
|
||||
out_ref, conv_state_ref = mamba_conv1d_ref(
|
||||
x, conv_state, conv_weight, conv_bias, apply_silu
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1)
|
||||
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
|
||||
|
||||
def test_with_initial_state(self):
|
||||
"""Test with non-zero initial conv state."""
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_size = 4
|
||||
dim = 1024
|
||||
seq_len = 32
|
||||
dconv = 4
|
||||
apply_silu = True
|
||||
|
||||
x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
|
||||
# Non-zero initial state
|
||||
conv_state = torch.randn(batch_size, dim, dconv - 1, dtype=dtype, device=device)
|
||||
conv_state = conv_state * 0.5
|
||||
conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
|
||||
conv_bias = torch.randn(dim, dtype=dtype, device=device)
|
||||
conv_state_kernel = conv_state.clone()
|
||||
# Need to tell the kernel about initial state
|
||||
has_initial_state = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
query_start_loc = torch.tensor(
|
||||
[0] + [seq_len * (i + 1) for i in range(batch_size)],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
# Reshape for varlen format
|
||||
x_varlen = x.transpose(1, 2).reshape(-1, dim).T.contiguous()
|
||||
|
||||
conv_weight_input = conv_weight.squeeze(1).contiguous()
|
||||
torch.ops.trtllm.causal_conv1d_fwd(
|
||||
x_varlen,
|
||||
conv_weight_input,
|
||||
conv_bias,
|
||||
conv_state_kernel,
|
||||
query_start_loc,
|
||||
None, # cache_indices
|
||||
has_initial_state,
|
||||
apply_silu,
|
||||
PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
out_ref_list = []
|
||||
conv_state_ref_list = []
|
||||
for b in range(batch_size):
|
||||
out_b, state_b = mamba_conv1d_ref(
|
||||
x[b : b + 1],
|
||||
conv_state[b : b + 1],
|
||||
conv_weight,
|
||||
conv_bias,
|
||||
apply_silu,
|
||||
)
|
||||
out_ref_list.append(out_b)
|
||||
conv_state_ref_list.append(state_b)
|
||||
out_ref = torch.cat(out_ref_list, dim=0)
|
||||
conv_state_ref = torch.cat(conv_state_ref_list, dim=0)
|
||||
x_kernel_reshaped = (
|
||||
x_varlen.T.reshape(batch_size, seq_len, dim).transpose(1, 2).contiguous()
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_kernel_reshaped, out_ref, rtol=1e-2, atol=1e-1)
|
||||
torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
|
||||
113  tests/unittest/_torch/modules/mamba/test_fuse_elementwise_ops.py  Normal file
@ -0,0 +1,113 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for fused elementwise operations in Mamba2 prefill."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.modules.mamba.fuse_elementwise_ops import (
|
||||
extract_transpose_xbc_prefill,
|
||||
fused_split_rearrange_after_conv1d,
|
||||
)
|
||||
|
||||
skip_no_cuda = pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA required for triton kernels",
|
||||
)
|
||||
|
||||
|
||||
def extract_transpose_xbc_prefill_ref(
|
||||
zxbcdt: torch.Tensor,
|
||||
num_prefill_tokens: int,
|
||||
d_inner: int,
|
||||
conv_dim: int,
|
||||
) -> torch.Tensor:
|
||||
"""Reference implementation for extract_transpose_xbc_prefill."""
|
||||
# Extract the xbc slice and transpose
|
||||
xbc = zxbcdt[:num_prefill_tokens, d_inner : d_inner + conv_dim]
|
||||
return xbc.transpose(0, 1).contiguous()
|
||||
|
||||
|
||||
def fused_split_rearrange_after_conv1d_ref(
|
||||
xbc: torch.Tensor,
|
||||
d_inner: int,
|
||||
n_groups: int,
|
||||
d_state: int,
|
||||
nheads: int,
|
||||
head_dim: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Reference implementation for fused_split_rearrange_after_conv1d."""
|
||||
conv_dim, num_prefill_tokens = xbc.shape
|
||||
bc_size = n_groups * d_state
|
||||
|
||||
# Transpose and split
|
||||
xbc_t = xbc.transpose(0, 1).contiguous() # [num_prefill_tokens, conv_dim]
|
||||
x, B, C = torch.split(xbc_t, [d_inner, bc_size, bc_size], dim=-1)
|
||||
x = x.contiguous().view(1, num_prefill_tokens, nheads, head_dim)
|
||||
B = B.contiguous().view(1, num_prefill_tokens, n_groups, d_state)
|
||||
C = C.contiguous().view(1, num_prefill_tokens, n_groups, d_state)
|
||||
return x, B, C
|
||||
|
||||
|
||||
@skip_no_cuda
|
||||
@pytest.mark.parametrize("num_prefill_tokens", [1, 32, 128, 1024])
|
||||
@pytest.mark.parametrize(
|
||||
"d_inner,conv_dim,d_in_proj", [(256, 512, 800), (512, 1024, 1600), (1024, 2048, 3200)]
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_extract_transpose_xbc_prefill(num_prefill_tokens, d_inner, conv_dim, d_in_proj, dtype):
|
||||
"""Test extract_transpose_xbc_prefill matches reference implementation."""
|
||||
torch.manual_seed(42)
|
||||
device = torch.device("cuda")
|
||||
|
||||
num_total_tokens = num_prefill_tokens + 16
|
||||
zxbcdt = torch.randn(num_total_tokens, d_in_proj, dtype=dtype, device=device)
|
||||
out_ref = extract_transpose_xbc_prefill_ref(zxbcdt, num_prefill_tokens, d_inner, conv_dim)
|
||||
out_fused = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens, d_inner, conv_dim)
|
||||
|
||||
assert out_fused.shape == out_ref.shape, f"Shape mismatch: {out_fused.shape} vs {out_ref.shape}"
|
||||
torch.testing.assert_close(out_fused, out_ref, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@skip_no_cuda
|
||||
@pytest.mark.parametrize("num_prefill_tokens", [1, 32, 128, 1024])
|
||||
@pytest.mark.parametrize(
|
||||
"nheads,head_dim,n_groups,d_state", [(8, 64, 1, 128), (16, 64, 2, 64), (32, 64, 4, 64)]
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_fused_split_rearrange_after_conv1d(
|
||||
num_prefill_tokens, nheads, head_dim, n_groups, d_state, dtype
|
||||
):
|
||||
"""Test fused_split_rearrange_after_conv1d matches reference implementation."""
|
||||
torch.manual_seed(42)
|
||||
device = torch.device("cuda")
|
||||
|
||||
d_inner = nheads * head_dim
|
||||
bc_size = n_groups * d_state
|
||||
conv_dim = d_inner + 2 * bc_size
|
||||
xbc = torch.randn(conv_dim, num_prefill_tokens, dtype=dtype, device=device)
|
||||
x_ref, B_ref, C_ref = fused_split_rearrange_after_conv1d_ref(
|
||||
xbc, d_inner, n_groups, d_state, nheads, head_dim
|
||||
)
|
||||
x_fused, B_fused, C_fused = fused_split_rearrange_after_conv1d(
|
||||
xbc, d_inner, n_groups, d_state, nheads, head_dim
|
||||
)
|
||||
|
||||
assert x_fused.shape == x_ref.shape, f"x shape mismatch: {x_fused.shape} vs {x_ref.shape}"
|
||||
assert B_fused.shape == B_ref.shape, f"B shape mismatch: {B_fused.shape} vs {B_ref.shape}"
|
||||
assert C_fused.shape == C_ref.shape, f"C shape mismatch: {C_fused.shape} vs {C_ref.shape}"
|
||||
torch.testing.assert_close(x_fused, x_ref, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(B_fused, B_ref, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(C_fused, C_ref, rtol=1e-3, atol=1e-3)
|
||||
133  tests/unittest/_torch/modules/mamba/test_mamba2_metadata.py  Normal file
@ -0,0 +1,133 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for Mamba2 metadata preparation optimizations."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import (
|
||||
cu_seqlens_to_chunk_indices_offsets,
|
||||
cu_seqlens_to_chunk_indices_offsets_triton,
|
||||
)
|
||||
|
||||
skip_no_cuda = pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA required for triton kernels",
|
||||
)
|
||||
|
||||
|
||||
@skip_no_cuda
|
||||
class TestCuSeqlensToChunkIndicesOffsets:
|
||||
"""Tests for cu_seqlens_to_chunk_indices_offsets_triton function."""
|
||||
|
||||
def test_empty_sequence(self):
|
||||
"""Test with empty cu_seqlens (no sequences)."""
|
||||
cu_seqlens = torch.tensor([0], dtype=torch.int, device="cuda")
|
||||
chunk_size = 8
|
||||
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
assert indices_triton.numel() == 0
|
||||
assert offsets_triton.numel() == 0
|
||||
|
||||
def test_single_sequence_aligned(self):
|
||||
"""Test with a single sequence that aligns with chunk size."""
|
||||
cu_seqlens = torch.tensor([0, 16], dtype=torch.int, device="cuda")
|
||||
chunk_size = 8
|
||||
|
||||
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
torch.testing.assert_close(indices_triton, indices_ref)
|
||||
torch.testing.assert_close(offsets_triton, offsets_ref)
|
||||
|
||||
def test_single_sequence_unaligned(self):
|
||||
"""Test with a single sequence that doesn't align with chunk size."""
|
||||
cu_seqlens = torch.tensor([0, 10], dtype=torch.int, device="cuda")
|
||||
chunk_size = 8
|
||||
|
||||
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
torch.testing.assert_close(indices_triton, indices_ref)
|
||||
torch.testing.assert_close(offsets_triton, offsets_ref)
|
||||
|
||||
def test_two_sequences_aligned(self):
|
||||
"""Test with two sequences, both aligned with chunk boundaries."""
|
||||
cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int, device="cuda")
|
||||
chunk_size = 8
|
||||
|
||||
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
torch.testing.assert_close(indices_triton, indices_ref)
|
||||
torch.testing.assert_close(offsets_triton, offsets_ref)
|
||||
|
||||
def test_two_sequences_misaligned(self):
|
||||
"""Test with two sequences where second starts at misaligned position."""
|
||||
# Example from docstring: cu_seqlens = [0, 5, 10], chunk_size = 8
|
||||
# -> chunk_indices = [0, 0, 1], chunk_offsets = [0, 5, 0]
|
||||
cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.int, device="cuda")
|
||||
chunk_size = 8
|
||||
|
||||
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
# Verify against expected values from docstring
|
||||
expected_indices = torch.tensor([0, 0, 1], dtype=torch.int, device="cuda")
|
||||
expected_offsets = torch.tensor([0, 5, 0], dtype=torch.int, device="cuda")
|
||||
|
||||
torch.testing.assert_close(indices_ref, expected_indices)
|
||||
torch.testing.assert_close(offsets_ref, expected_offsets)
|
||||
|
||||
torch.testing.assert_close(indices_triton, indices_ref)
|
||||
torch.testing.assert_close(offsets_triton, offsets_ref)
|
||||
|
||||
@pytest.mark.parametrize("chunk_size", [8, 16, 32, 64, 128])
|
||||
def test_multiple_sequences_various_chunk_sizes(self, chunk_size):
|
||||
"""Test with multiple sequences and various chunk sizes."""
|
||||
# Create sequences with varying lengths
|
||||
cu_seqlens = torch.tensor([0, 10, 25, 40, 60, 75], dtype=torch.int, device="cuda")
|
||||
|
||||
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
torch.testing.assert_close(indices_triton, indices_ref)
|
||||
torch.testing.assert_close(offsets_triton, offsets_ref)
|
||||
|
||||
def test_all_sequences_within_one_chunk(self):
|
||||
"""Test when all sequences fit within a single chunk."""
|
||||
cu_seqlens = torch.tensor([0, 2, 4, 6], dtype=torch.int, device="cuda")
|
||||
chunk_size = 64 # Large chunk size
|
||||
|
||||
indices_ref, offsets_ref = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
|
||||
indices_triton, offsets_triton = cu_seqlens_to_chunk_indices_offsets_triton(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
|
||||
torch.testing.assert_close(indices_triton, indices_ref)
|
||||
torch.testing.assert_close(offsets_triton, offsets_ref)
|
||||
223  tests/unittest/_torch/modules/test_fused_activation_quant.py  Normal file
@ -0,0 +1,223 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for fused relu2 + NVFP4 quantization kernel."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tests.unittest.utils.util import getSMVersion
|
||||
|
||||
|
||||
def fused_relu2_quantize_available():
|
||||
"""Check if the fused_relu2_quantize op is available."""
|
||||
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fused_relu2_quantize")
|
||||
|
||||
|
||||
def fp4_quantize_available():
|
||||
"""Check if the fp4_quantize op is available."""
|
||||
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fp4_quantize")
|
||||
|
||||
|
||||
skip_unless_fused_relu2_quantize = pytest.mark.skipif(
|
||||
getSMVersion() < 100 or not fused_relu2_quantize_available(),
|
||||
reason="Requires SM100+ (Blackwell) and trtllm.fused_relu2_quantize op",
|
||||
)
|
||||
|
||||
skip_unless_fused_relu2_and_fp4_quantize = pytest.mark.skipif(
|
||||
getSMVersion() < 100 or not fused_relu2_quantize_available() or not fp4_quantize_available(),
|
||||
reason="Requires SM100+ (Blackwell) and trtllm fused_relu2_quantize + fp4_quantize ops",
|
||||
)
|
||||
|
||||
|
||||
# FP4 E2M1 lookup table for reference implementation
|
||||
E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
|
||||
|
||||
|
||||
def relu2(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Reference relu2 activation: square(relu(x))."""
|
||||
return torch.square(F.relu(x))
|
||||
|
||||
|
||||
def cast_to_fp4(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Cast tensor values to FP4 E2M1 format (as uint8)."""
|
||||
device = weight.device
|
||||
|
||||
mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
|
||||
mask_shape = list(weight.shape)
|
||||
mask = mask.expand([*mask_shape, 7])
|
||||
|
||||
sign_bit = (weight < 0).to(torch.uint8)
|
||||
weight_abs = weight.abs()
|
||||
|
||||
ord_val = torch.searchsorted(E2M1_BOUNDS.to(device), weight_abs, out_int32=True).to(torch.uint8)
|
||||
round_val = torch.any((weight_abs.unsqueeze(-1) == E2M1_BOUNDS.to(device)) * mask, dim=-1)
|
||||
fp4_val = (sign_bit * 0b1000 + ord_val + round_val).to(torch.uint8)
|
||||
return fp4_val
|
||||
|
||||
|
||||
def quantize_nvfp4_ref(
|
||||
input: torch.Tensor, sf_scale: torch.Tensor, sf_vec_size: int = 16
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Reference NVFP4 quantization implementation.
|
||||
|
||||
Args:
|
||||
input: Input tensor [M, N], already activated (e.g., after relu2)
|
||||
sf_scale: Per-tensor scaling factor (sf_scale = amax / (6 * 448))
|
||||
sf_vec_size: Block size for per-block scaling (default 16)
|
||||
|
||||
Returns:
|
||||
Tuple of (fp4_packed, scale_factors)
|
||||
"""
|
||||
m, n = input.shape
|
||||
assert n % sf_vec_size == 0, f"N ({n}) must be divisible by sf_vec_size ({sf_vec_size})"
|
||||
|
||||
# Reshape for block-wise quantization
|
||||
input_blocked = input.view(m, n // sf_vec_size, sf_vec_size)
|
||||
|
||||
# Compute per-block amax
|
||||
per_block_amax = input_blocked.abs().amax(dim=-1).float()
|
||||
|
||||
# Compute per-block scale: amax / 6.0
|
||||
per_block_scale = per_block_amax / 6.0
|
||||
|
||||
# Quantize per-block scale to FP8
|
||||
q_per_block_scale = per_block_scale / sf_scale
|
||||
q_per_block_scale[per_block_scale == 0] = 1.0
|
||||
q_per_block_scale_fp8 = q_per_block_scale.to(torch.float8_e4m3fn)
|
||||
|
||||
# Dequantize scale for actual quantization
|
||||
scale_dequant = q_per_block_scale_fp8.float() * sf_scale
|
||||
|
||||
# Scale the input
|
||||
scale_expanded = scale_dequant.unsqueeze(-1).expand_as(input_blocked)
|
||||
scaled_input = input_blocked / (scale_expanded + 1e-12)
|
||||
scaled_input = scaled_input.view(m, n)
|
||||
|
||||
# Cast to FP4
|
||||
fp4_vals = cast_to_fp4(scaled_input)
|
||||
|
||||
# Pack two FP4 values into one uint8
|
||||
packed = (fp4_vals[..., 1::2] << 4) | fp4_vals[..., 0::2]
|
||||
|
||||
return packed, q_per_block_scale_fp8
|
||||
|
||||
|
||||
def fused_relu2_quantize_ref(
|
||||
input: torch.Tensor, sf_scale: torch.Tensor, sf_vec_size: int = 16
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Reference implementation for fused relu2 + NVFP4 quantization.
|
||||
|
||||
Args:
|
||||
input: Input tensor [M, N]
|
||||
sf_scale: Per-tensor scaling factor
|
||||
sf_vec_size: Block size for per-block scaling (default 16)
|
||||
|
||||
Returns:
|
||||
Tuple of (fp4_packed, scale_factors)
|
||||
"""
|
||||
# Apply relu2 activation
|
||||
activated = relu2(input)
|
||||
# Quantize to NVFP4
|
||||
return quantize_nvfp4_ref(activated, sf_scale, sf_vec_size)
|
||||
|
||||
|
||||
@skip_unless_fused_relu2_quantize
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_fused_relu2_quantize_zeros(dtype):
|
||||
"""Test fused_relu2_quantize with inputs that produce zeros after relu2."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# All negative inputs -> relu2 produces all zeros
|
||||
m, n = 32, 64
|
||||
input_tensor = -torch.abs(torch.randn(m, n, dtype=dtype, device=device))
|
||||
sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
|
||||
fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(input_tensor, sf_scale, 16)
|
||||
|
||||
assert fp4_fused.shape == (m, n // 2)
|
||||
assert (fp4_fused == 0).all(), "All negative inputs should produce zero output"
|
||||
|
||||
|
||||
@skip_unless_fused_relu2_and_fp4_quantize
|
||||
@pytest.mark.parametrize("m", [1, 16, 64, 128])
|
||||
@pytest.mark.parametrize("n", [32, 64, 128, 256])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_fused_relu2_quantize_vs_separate_ops(m, n, dtype):
|
||||
"""
|
||||
Compare fused_relu2_quantize kernel output against separate relu2 + fp4_quantize.
|
||||
|
||||
This test verifies that the fused CUDA kernel produces FP4 packed values that
|
||||
closely match running relu2 activation followed by fp4_quantize separately.
|
||||
|
||||
Note: Due to floating point precision differences in intermediate calculations
|
||||
(e.g., FMA vs separate mul+add), a small percentage of values at quantization
|
||||
boundaries may differ. We require >= 99% match rate.
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
device = torch.device("cuda")
|
||||
|
||||
input_tensor = torch.randn(m, n, dtype=dtype, device=device)
|
||||
activated = relu2(input_tensor)
|
||||
sf_scale = (activated.abs().amax().float() / (6.0 * 448.0)).to(device)
|
||||
sf_scale = sf_scale.view(1)
|
||||
|
||||
fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize(
|
||||
activated,
|
||||
sf_scale,
|
||||
16,
|
||||
False,
|
||||
True, # use_ue8m0=False, is_sf_swizzled_layout=True
|
||||
)
|
||||
fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(
|
||||
input_tensor.contiguous(), sf_scale, 16
|
||||
)
|
||||
|
||||
match_rate = (fp4_fused == fp4_separate).float().mean().item()
|
||||
assert match_rate >= 0.99, (
|
||||
f"FP4 values match rate {match_rate:.4f} < 0.99 for shape ({m}, {n}), dtype {dtype}"
|
||||
)
|
||||
|
||||
|
||||
@skip_unless_fused_relu2_and_fp4_quantize
|
||||
def test_fused_relu2_quantize_vs_separate_ops_various_sf_scales():
|
||||
"""
|
||||
Test with various sf_scale values to ensure consistent behavior.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
m, n = 64, 128
|
||||
dtype = torch.bfloat16
|
||||
|
||||
torch.manual_seed(123)
|
||||
input_tensor = torch.randn(m, n, dtype=dtype, device=device)
|
||||
activated = relu2(input_tensor)
|
||||
|
||||
# Test with different sf_scale values
|
||||
for scale_multiplier in [0.1, 1.0, 10.0]:
|
||||
sf_scale = (
|
||||
(activated.abs().amax().float() / (6.0 * 448.0) * scale_multiplier).to(device).view(1)
|
||||
)
|
||||
fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize(
|
||||
activated, sf_scale, 16, False, True
|
||||
)
|
||||
fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(
|
||||
input_tensor.contiguous(), sf_scale, 16
|
||||
)
|
||||
|
||||
match_rate = (fp4_fused == fp4_separate).float().mean().item()
|
||||
assert match_rate >= 0.99, (
|
||||
f"FP4 values match rate {match_rate:.4f} < 0.99 with scale_multiplier={scale_multiplier}"
|
||||
)
|
||||
336  tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py  Normal file
@ -0,0 +1,336 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for fused_add_rms_norm_quant with/without output_hp_norm support."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.unittest.utils.util import getSMVersion
|
||||
|
||||
|
||||
def fused_add_rms_norm_quant_available():
|
||||
"""Check if the fused_add_rms_norm_quant op is available."""
|
||||
return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fused_add_rms_norm_quant")
|
||||
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
getSMVersion() < 100 or not fused_add_rms_norm_quant_available(),
|
||||
reason="Requires Blackwell+ (SM100+) and trtllm.fused_add_rms_norm_quant op",
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_ref(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
gamma: torch.Tensor,
|
||||
eps: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Reference RMSNorm implementation with residual addition.
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor [M, N]
|
||||
residual: Residual tensor [M, N]
|
||||
gamma: Weight tensor [N]
|
||||
eps: Epsilon for numerical stability
|
||||
|
||||
Returns:
|
||||
Tuple of (normalized_output, residual_output)
|
||||
"""
|
||||
input_dtype = hidden_states.dtype
|
||||
|
||||
# Add residual
|
||||
hidden_states_fp32 = hidden_states.float() + residual.float()
|
||||
residual_out = hidden_states_fp32.to(input_dtype)
|
||||
|
||||
# RMSNorm
|
||||
variance = hidden_states_fp32.pow(2).mean(-1, keepdim=True)
|
||||
normed = hidden_states_fp32 * torch.rsqrt(variance + eps)
|
||||
normed_output = (gamma.float() * normed).to(input_dtype)
|
||||
|
||||
return normed_output, residual_out
|
||||
|
||||
|
||||
def get_swizzled_sf_indices(m: int, n: int, sf_vec_size: int = 16) -> list[int]:
|
||||
"""
|
||||
Compute the valid indices in swizzled SF layout for given m and n.
|
||||
|
||||
The swizzled layout uses 128x4 tiles:
|
||||
- SF layout: [numMTiles, numKTiles, 32 (outerM), 4 (innerM), 4 (innerK)]
|
||||
- Each SF block has 128 rows, padded to multiple of 128
|
||||
- Each SF block has columns padded to multiple of 4
|
||||
|
||||
Args:
|
||||
m: Number of rows
|
||||
n: Hidden dimension
|
||||
sf_vec_size: Number of elements sharing one scale factor (default 16)
|
||||
|
||||
Returns:
|
||||
List of valid indices in the swizzled buffer
|
||||
"""
|
||||
num_col_vecs = n // sf_vec_size # Number of SF columns
|
||||
indices = []
|
||||
|
||||
for m_idx in range(m):
|
||||
for k_idx in range(num_col_vecs):
|
||||
# Compute swizzled offset using 128x4 tile layout
|
||||
inner_k_idx = k_idx % 4
|
||||
inner_k_stride = 1
|
||||
|
||||
inner_m_idx = (m_idx % 128) // 32
|
||||
inner_m_stride = 4 * inner_k_stride # 4
|
||||
|
||||
outer_m_idx = m_idx % 32
|
||||
outer_m_stride = 4 * inner_m_stride # 16
|
||||
|
||||
k_tile_idx = k_idx // 4
|
||||
k_tile_stride = 32 * outer_m_stride # 512
|
||||
|
||||
num_k_tiles = (num_col_vecs + 3) // 4
|
||||
m_tile_idx = m_idx // 128
|
||||
m_tile_stride = num_k_tiles * k_tile_stride
|
||||
|
||||
offset = (
|
||||
m_tile_idx * m_tile_stride
|
||||
+ k_tile_idx * k_tile_stride
|
||||
+ outer_m_idx * outer_m_stride
|
||||
+ inner_m_idx * inner_m_stride
|
||||
+ inner_k_idx * inner_k_stride
|
||||
)
|
||||
indices.append(offset)
|
||||
|
||||
return indices
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@pytest.mark.parametrize("m", [1, 16, 64, 128])
|
||||
@pytest.mark.parametrize("n", [2048, 4096])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_fused_add_rms_norm_quant_basic(m, n, dtype):
|
||||
"""Test basic functionality of fused_add_rms_norm_quant without hp_output."""
|
||||
torch.manual_seed(42)
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Create input tensors
|
||||
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
|
||||
residual = torch.randn(m, n, dtype=dtype, device=device)
|
||||
gamma = torch.ones(n, dtype=dtype, device=device)
|
||||
|
||||
# Compute sf_scale (per-tensor scale)
|
||||
eps = 1e-6
|
||||
normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps)
|
||||
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
|
||||
|
||||
# Run fused kernel without hp_output
|
||||
normed_fp4, residual_out, sf_out, dummy_output = torch.ops.trtllm.fused_add_rms_norm_quant(
|
||||
hidden_states, residual, gamma, sf_scale, True, eps=eps
|
||||
)
|
||||
assert dummy_output is None
|
||||
|
||||
# Verify output shapes
|
||||
assert normed_fp4.shape[0] == m
|
||||
assert residual_out.shape == (m, n)
|
||||
assert residual_out.dtype == dtype
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@pytest.mark.parametrize("m", [1, 16, 64, 128])
|
||||
@pytest.mark.parametrize("n", [2048, 4096])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_fused_add_rms_norm_quant_with_hp_output(m, n, dtype):
|
||||
"""Test fused_add_rms_norm_quant with output_hp_norm=True."""
|
||||
torch.manual_seed(42)
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Create input tensors
|
||||
hidden_states = torch.randn(m, n, dtype=dtype, device=device)
|
||||
residual = torch.randn(m, n, dtype=dtype, device=device)
|
||||
gamma = torch.ones(n, dtype=dtype, device=device)
|
||||
|
||||
# Compute sf_scale
|
||||
eps = 1e-6
|
||||
normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps)
|
||||
sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)
|
||||
|
||||
# Run fused kernel with hp_output
|
||||
results = torch.ops.trtllm.fused_add_rms_norm_quant(
|
||||
hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
|
||||
)
|
||||
|
||||
# Should return 4 tensors when output_hp_norm=True
|
||||
assert len(results) == 4, f"Expected 4 outputs, got {len(results)}"
|
||||
|
||||
normed_fp4, residual_out, sf_out, hp_normed_output = results
|
||||
|
||||
# Verify output shapes
|
||||
assert normed_fp4.shape[0] == m
|
||||
assert residual_out.shape == (m, n)
|
||||
assert hp_normed_output.shape == (m, n)
|
||||
|
||||
# Verify dtypes
|
||||
assert residual_out.dtype == dtype
|
||||
assert hp_normed_output.dtype == dtype
|
||||
|
||||
# Verify high precision output matches reference
|
||||
torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
# Verify residual output matches reference
|
||||
torch.testing.assert_close(residual_out, residual_ref, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_add_rms_norm_quant_hp_output_consistency(dtype):
    """Test that hp_output is consistent with the quantized output."""
    torch.manual_seed(42)
    device = torch.device("cuda")

    m, n = 64, 4096

    # Create input tensors
    hidden_states = torch.randn(m, n, dtype=dtype, device=device)
    residual = torch.randn(m, n, dtype=dtype, device=device)
    gamma = torch.ones(n, dtype=dtype, device=device)

    eps = 1e-6
    normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps)
    sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)

    # Run without hp_output
    results_no_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
        hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=False
    )
    assert results_no_hp[3] is None
    normed_fp4_no_hp, residual_out_no_hp, sf_out_no_hp = results_no_hp[:3]

    # Run with hp_output
    results_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
        hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
    )
    normed_fp4_hp, residual_out_hp, sf_out_hp, hp_normed_output = results_hp

    # The quantized outputs should be identical regardless of hp_output flag
    torch.testing.assert_close(normed_fp4_hp, normed_fp4_no_hp, rtol=0, atol=0)
    torch.testing.assert_close(residual_out_hp, residual_out_no_hp, rtol=0, atol=0)
    # Compare only valid SF indices (swizzled layout pads rows to 128)
    valid_sf_indices = get_swizzled_sf_indices(m, n)
    sf_out_hp_valid = sf_out_hp[valid_sf_indices]
    sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices]
    torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0)

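
# Illustrative sketch (an assumption about the swizzled scale-factor layout):
# with a 16-element quantization block each row yields n // 16 FP8 scale
# factors, and the swizzled layout pads the row dimension up to a multiple of
# 128, so the padded tail entries are uninitialized and must be excluded from
# exact comparisons such as the one above. A hypothetical helper for the padded
# entry count:
def _num_swizzled_sf_entries_sketch(m, n, block_size=16, row_pad=128):
    # Round the row count up to the padding granularity used by the layout.
    padded_rows = ((m + row_pad - 1) // row_pad) * row_pad
    return padded_rows * (n // block_size)
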
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_add_rms_norm_quant_gamma_weight(dtype):
    """Test fused_add_rms_norm_quant with non-trivial gamma weights."""
    torch.manual_seed(42)
    device = torch.device("cuda")

    m, n = 32, 2048

    # Create input tensors
    hidden_states = torch.randn(m, n, dtype=dtype, device=device)
    residual = torch.randn(m, n, dtype=dtype, device=device)
    # Non-trivial gamma weights
    gamma = torch.randn(n, dtype=dtype, device=device) * 0.5 + 1.0

    eps = 1e-6
    normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps)
    sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)

    # Run with hp_output
    results = torch.ops.trtllm.fused_add_rms_norm_quant(
        hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
    )
    normed_fp4, residual_out, sf_out, hp_normed_output = results

    # Verify high precision output matches reference
    torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2)

    # Verify residual output matches reference
    torch.testing.assert_close(residual_out, residual_ref, rtol=1e-3, atol=1e-3)

@skip_unsupported
def test_fused_add_rms_norm_quant_large_batch():
    """Test fused_add_rms_norm_quant with larger batch size."""
    torch.manual_seed(42)
    device = torch.device("cuda")

    m, n = 512, 4096
    dtype = torch.bfloat16

    hidden_states = torch.randn(m, n, dtype=dtype, device=device)
    residual = torch.randn(m, n, dtype=dtype, device=device)
    gamma = torch.ones(n, dtype=dtype, device=device)

    eps = 1e-6
    normed_ref, residual_ref = rms_norm_ref(hidden_states, residual, gamma, eps)
    sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)

    results = torch.ops.trtllm.fused_add_rms_norm_quant(
        hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
    )
    normed_fp4, residual_out, sf_out, hp_normed_output = results

    assert hp_normed_output.shape == (m, n)
    torch.testing.assert_close(hp_normed_output, normed_ref, rtol=1e-2, atol=1e-2)

@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_low_latency_layernorm_hp_output_consistency(dtype):
    """
    Test that low_latency_layernorm hp_output is consistent with/without the flag.

    The quantized outputs should be identical regardless of output_hp_norm flag.
    Uses m=1 to trigger the low_latency_layernorm path.
    """
    torch.manual_seed(42)
    device = torch.device("cuda")

    m, n = 1, 4096  # m=1 triggers low_latency_layernorm path

    hidden_states = torch.randn(m, n, dtype=dtype, device=device)
    residual = torch.randn(m, n, dtype=dtype, device=device)
    gamma = torch.ones(n, dtype=dtype, device=device)

    eps = 1e-6
    normed_ref, _ = rms_norm_ref(hidden_states, residual, gamma, eps)
    sf_scale = (normed_ref.abs().amax().float() / (6.0 * 448.0)).view(1)

    # Run without hp_output
    results_no_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
        hidden_states, residual, gamma, sf_scale, True, eps=eps
    )
    assert len(results_no_hp) == 4, f"Expected 4 outputs, got {len(results_no_hp)}"
    assert results_no_hp[3] is None, "Expected 4th output to be None when output_hp_norm=False"
    normed_fp4_no_hp, residual_out_no_hp, sf_out_no_hp = results_no_hp[:3]

    # Run with hp_output
    results_hp = torch.ops.trtllm.fused_add_rms_norm_quant(
        hidden_states, residual, gamma, sf_scale, True, eps=eps, output_hp_norm=True
    )
    assert len(results_hp) == 4, f"Expected 4 outputs, got {len(results_hp)}"
    normed_fp4_hp, residual_out_hp, sf_out_hp, hp_normed_output = results_hp

    # The quantized outputs should be identical regardless of hp_output flag
    torch.testing.assert_close(normed_fp4_hp, normed_fp4_no_hp, rtol=0, atol=0)
    torch.testing.assert_close(residual_out_hp, residual_out_no_hp, rtol=0, atol=0)
    # Compare only valid SF indices (swizzled layout pads rows to 128)
    valid_sf_indices = get_swizzled_sf_indices(m, n)
    sf_out_hp_valid = sf_out_hp[valid_sf_indices]
    sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices]
    torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0)