[feat] Optimizations on weight-only batched gemv kernel (#5420)

Signed-off-by: Cheng Hang <chang@nvidia.com>
Cheng Hang 2025-06-30 10:20:16 +08:00 committed by GitHub
parent b4dab23e7b
commit 64db7d27f6
5 changed files with 83 additions and 50 deletions

View File

@@ -63,6 +63,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
= (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize
+ ((tid * StepK) % Details::LayoutDetails::kTileSize);
bool constexpr scale_zero_ldg128 = Details::kInterleave == 1 && CtaN == 8;
using AccessTypeScaleZero = std::conditional_t<scale_zero_ldg128, AccessTypeA, TypeA>;
GMemIterator<Mandatory, AccessTypeA, CtaM, Details::kAccessNumA, TypeA> act_iterator(
act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k);
GMemIterator<EnableActScale, AccessTypeA, 1, Details::kAccessNumA, TypeA> act_scale_iterator(
@@ -70,10 +74,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
GMemIterator<Mandatory, AccessTypeW, CtaN, Details::kAccessNumW, uint8_t> weight_iterator(weight,
(interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW,
interleaved_k / Details::kElemsPerByteW);
GMemIterator<Mandatory, TypeA, CtaN, 1, TypeA> scales_iterator(scales,
GMemIterator<Mandatory, AccessTypeScaleZero, CtaN, 1, TypeA> scales_iterator(scales,
(GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
(GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
GMemIterator<EnableZero, TypeA, CtaN, 1, TypeA> zeros_iterator(zeros,
GMemIterator<EnableZero, AccessTypeScaleZero, CtaN, 1, TypeA> zeros_iterator(zeros,
(GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
(GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
@@ -92,11 +96,19 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
TypeA vec_scale[CtaN], vec_zero[CtaN];
TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK];
uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW];
#pragma unroll
for (int i = 0; i < CtaN; ++i)
if constexpr (scale_zero_ldg128)
{
scales_iterator.load(vec_scale + i, iter, i);
zeros_iterator.load(vec_zero + i, iter, i);
scales_iterator.load(vec_scale, iter);
zeros_iterator.load(vec_zero, iter);
}
else
{
#pragma unroll
for (int i = 0; i < CtaN; ++i)
{
scales_iterator.load(vec_scale + i, iter, i);
zeros_iterator.load(vec_zero + i, iter, i);
}
}
act_scale_iterator.load(vec_act_scale, iter);
#pragma unroll
@@ -112,7 +124,8 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
{
act_iterator.load(tile_a, iter, i);
apply_scale<Details, 1, StepK, EnableActScale>(tile_a, vec_act_scale);
mma<Details, 1, CtaN, StepK>(tile_acc + i * CtaN, tile_w_pack2, tile_a);
mma<Details, 1, CtaN, StepK, EnableZero, ApplyAlphaInAdvance>(
tile_acc + i * CtaN, tile_w_pack2, tile_a, vec_scale);
}
}
epilogue<Details, CtaM, CtaN, Threads, EnableBias, ApplyAlphaInAdvance>(out, n, tile_acc, bias, alpha);
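The new scale_zero_ldg128 fast path above exploits the fact that with Details::kInterleave == 1 and CtaN == 8, the eight per-column half-precision scales (and zeros) a thread needs are contiguous in global memory, so they can be fetched with one 128-bit load instead of eight scalar loads. A minimal standalone sketch of that access pattern, assuming 16-byte-aligned half data; the function names are illustrative and not code from this commit:

    #include <cuda_fp16.h>

    // Vectorized path: one 16-byte (LDG.128) transaction covers all 8 half scales.
    // Assumes both `scales` and `vec_scale` are 16-byte aligned.
    __device__ __forceinline__ void load_scales_ldg128(half const* scales, half* vec_scale)
    {
        *reinterpret_cast<float4*>(vec_scale) = *reinterpret_cast<float4 const*>(scales);
    }

    // Scalar fallback: eight separate global loads, as in the pre-existing per-column loop.
    __device__ __forceinline__ void load_scales_scalar(half const* scales, half* vec_scale)
    {
    #pragma unroll
        for (int i = 0; i < 8; ++i)
        {
            vec_scale[i] = scales[i];
        }
    }

With interleaved weight layouts the iterator advances the per-column offset by Details::kInterleave, so consecutive loads are not contiguous, which is presumably why the scalar per-column path is kept for that case.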

View File

@@ -133,33 +133,37 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca
{
Converter::convert<K>(reinterpret_cast<uint8_t*>(quantized_w) + n * K / Details::kElemsPerByteW,
reinterpret_cast<Type*>(w) + n * K);
Type2 vec_scale, vec_zero;
if constexpr (ApplyAlphaInAdvance)
if constexpr (EnableZero || ApplyAlphaInAdvance)
{
// For W4A8, we assume the scales/zeros are always half, regardless of whether the activation dtype is bf16 or fp16
Type scales_ = static_cast<float>(reinterpret_cast<half*>(scales)[n]) * alpha;
vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(scales_);
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
if constexpr (EnableZero)
Type2 vec_scale, vec_zero;
if constexpr (ApplyAlphaInAdvance)
{
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(
static_cast<float>(reinterpret_cast<half*>(zeros)[n]) * alpha);
Type scales_ = static_cast<float>(reinterpret_cast<half*>(scales)[n]) * alpha;
vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(scales_);
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
if constexpr (EnableZero)
{
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(
static_cast<float>(reinterpret_cast<half*>(zeros)[n]) * alpha);
}
}
}
else
{
vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(scales)[n]);
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
if constexpr (EnableZero)
else
{
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(zeros)[n]);
vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(scales)[n]);
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
if constexpr (EnableZero)
{
vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(zeros)[n]);
}
}
}
#pragma unroll
for (int k = 0; k < VecK; ++k)
{
reinterpret_cast<Type2*>(w)[n * VecK + k] = MathWrapper<typename Details::TypeDetailsA>::fma2(
reinterpret_cast<Type2*>(w)[n * VecK + k], vec_scale, vec_zero);
for (int k = 0; k < VecK; ++k)
{
reinterpret_cast<Type2*>(w)[n * VecK + k] = MathWrapper<typename Details::TypeDetailsA>::fma2(
reinterpret_cast<Type2*>(w)[n * VecK + k], vec_scale, vec_zero);
}
}
}
}
@@ -177,8 +181,8 @@ __device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n)
}
}
template <typename Details, int M, int N, int K>
__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
template <typename Details, int M, int N, int K, bool EnableZero, bool ApplyAlphaInAdvance>
__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act, void* scale)
{
using Type = typename MathWrapper<typename Details::TypeDetailsA>::Type;
using Type2 = typename MathWrapper<typename Details::TypeDetailsA>::Type2;
@@ -190,13 +194,30 @@ __device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
#pragma unroll
for (int n = 0; n < VecN; ++n)
{
#pragma unroll
for (int k = 0; k < K; ++k)
if constexpr (EnableZero || ApplyAlphaInAdvance)
{
reinterpret_cast<Type2*>(acc)[m * VecN + n]
= MathWrapper<typename Details::TypeDetailsA>::fma2(reinterpret_cast<Type2*>(w_pack2)[n * K + k],
#pragma unroll
for (int k = 0; k < K; ++k)
{
reinterpret_cast<Type2*>(acc)[m * VecN + n] = MathWrapper<typename Details::TypeDetailsA>::fma2(
reinterpret_cast<Type2*>(w_pack2)[n * K + k],
MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]),
reinterpret_cast<Type2*>(acc)[m * VecN + n]);
}
}
else
{
Type2 local_acc{};
#pragma unroll
for (int k = 0; k < K; ++k)
{
local_acc = MathWrapper<typename Details::TypeDetailsA>::fma2(
reinterpret_cast<Type2*>(w_pack2)[n * K + k],
MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]),
local_acc);
}
reinterpret_cast<Type2*>(acc)[m * VecN + n] = MathWrapper<typename Details::TypeDetailsA>::fma2(
local_acc, reinterpret_cast<Type2*>(scale)[n], reinterpret_cast<Type2*>(acc)[m * VecN + n]);
}
}
}
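The reworked mma above defers the scale multiplication when that is mathematically safe (no zero points and no alpha folded into the scales): since sum_k (s * w_k) * a_k = s * sum_k (w_k * a_k), the per-column scale can be applied once per K-tile to a local accumulator instead of once per element during dequantization. A simplified scalar sketch of that identity, with illustrative names rather than the kernel's real types:

    // acc += sum_k (w[k] * scale) * a[k]  is computed as  acc += scale * (sum_k w[k] * a[k]).
    // `w` holds weights already converted from int4/int8 to float, but not yet scaled.
    __device__ __forceinline__ float dot_then_scale(
        float const* w, float const* a, float scale, int K, float acc)
    {
        float local_acc = 0.f;
        for (int k = 0; k < K; ++k)
        {
            local_acc = fmaf(w[k], a[k], local_acc); // no per-element scaling inside the loop
        }
        return fmaf(local_acc, scale, acc); // single multiply by the scale per tile
    }

When zero points are enabled, or alpha is pre-applied to the scales, the dequantized weights are no longer a pure product of scale and quantized value, so dequantize still folds scale and zero in up front and mma keeps the original per-element FMA path, matching the EnableZero || ApplyAlphaInAdvance guards above.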

View File

@@ -165,13 +165,13 @@ struct cutlassTypeMapper
} \
};
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8Groupwise, "FP16Int8Groupwise", half, uint8_t, 8,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8Groupwise, "BF16Int8Groupwise", __nv_bfloat16, uint8_t, 8,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", half, cutlass::uint4b_t, 4,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, "BF16Int4Groupwise", __nv_bfloat16, cutlass::uint4b_t,
4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8PerChannel, "FP16Int8PerChannel", half, uint8_t, 8,
cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8PerChannel, "BF16Int8PerChannel", __nv_bfloat16, uint8_t, 8,
@@ -228,7 +228,7 @@ void exec_cutlass_kernel(
runner.gemm(
act, params.weight, params.scales, params.out, params.m, params.n, params.k, config, ws, ws_size, stream);
}
else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
{
runner.gemm(act, params.weight, params.scales, params.zeros, params.bias, params.out, params.m, params.n,
params.k, params.groupsize, config, ws, ws_size, stream);
@@ -290,7 +290,7 @@ float run_cutlass_kernel(wo::Params& params, int warmup, int iter)
msg << "\n (for"
<< " m=" << params.m << ", n=" << params.n << ", k=" << params.k << ")"
<< ", reason: \"" << e.what() << "\". Skipped\n";
std::cout << msg.str();
TLLM_LOG_TRACE(msg.str());
cudaGetLastError(); // Reset the last cudaError to cudaSuccess.
continue;
}
@@ -332,7 +332,7 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
{
simple_assert(groupsize == 0);
}
else if (cutlassTypeMapper<KT>::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
else if (cutlassTypeMapper<KT>::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
{
simple_assert(groupsize == 64 || groupsize == 128);
}
@@ -379,7 +379,6 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
if (groupsize != 0)
{
p_zeros = d_zeros.data();
p_bias = d_bias.data();
p_act_scale = d_act_scale.data();
}
@@ -402,9 +401,9 @@ TEST(Kernel, WeightOnly)
int const arch = tensorrt_llm::common::getSMVersion();
bool pass;
int warmup = 10, iter = 30;
std::vector<int> ms{2, 4, 6, 8, 10, 12, 14};
std::vector<int> ms{1, 2, 4, 6, 8, 10, 12, 14};
std::vector<int> ns{4096};
std::vector<int> ks{2048};
std::vector<int> ks{4096};
for (auto m : ms)
{
for (auto n : ns)
@@ -428,6 +427,10 @@ TEST(Kernel, WeightOnly)
#if defined(ENABLE_BF16)
if (arch >= 80)
{
pass = benchmark_and_verify<wo::KernelType::BF16Int8PerChannel>(m, n, k, 0, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int4PerChannel>(m, n, k, 0, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 64, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 128, warmup, iter);
@@ -436,10 +439,6 @@ TEST(Kernel, WeightOnly)
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int4Groupwise>(m, n, k, 128, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int8PerChannel>(m, n, k, 0, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int4PerChannel>(m, n, k, 0, warmup, iter);
EXPECT_TRUE(pass);
}
#endif
}

View File

@@ -644,7 +644,7 @@ class TestSmoothQuant(unittest.TestCase):
k = 4096
# Init operands for multiplication in int32
mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
mat1 = _utils.woq_gen_weights(m, k, dtype)
weight = _utils.woq_gen_weights(k, n, dtype)
ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(

View File

@@ -87,7 +87,7 @@ class TestWeightOnlyQuantMatmul(unittest.TestCase):
def _woq_matmul(self, m, n, k, dtype, wTypeId, use_plugin=True):
# Init operands for multiplication in int32
mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
mat1 = _utils.woq_gen_weights(m, k, dtype)
weight = _utils.woq_gen_weights(k, n, dtype)
ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(