From 64db7d27f60997563bd68c1a8ab1b057e8016dd4 Mon Sep 17 00:00:00 2001
From: Cheng Hang <854803517@qq.com>
Date: Mon, 30 Jun 2025 10:20:16 +0800
Subject: [PATCH] [feat] Optimizations on weight-only batched gemv kernel
 (#5420)

Signed-off-by: Cheng Hang
---
 .../kernels/weightOnlyBatchedGemv/kernel.h    | 27 +++++--
 .../kernels/weightOnlyBatchedGemv/utility.h   | 75 ++++++++++++-------
 .../weightOnly/weightOnlyKernelTest.cpp       | 27 ++++---
 .../trt/quantization/test_quant_layer.py      |  2 +-
 .../test_weight_only_quant_matmul.py          |  2 +-
 5 files changed, 83 insertions(+), 50 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h
index 7df9305d98..de4a960e14 100644
--- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h
@@ -63,6 +63,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
         = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize
         + ((tid * StepK) % Details::LayoutDetails::kTileSize);
 
+    bool constexpr scale_zero_ldg128 = Details::kInterleave == 1 && CtaN == 8;
+
+    using AccessTypeScaleZero = std::conditional_t;
+
     GMemIterator act_iterator(
         act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k);
     GMemIterator act_scale_iterator(
@@ -70,10 +74,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
     GMemIterator weight_iterator(weight,
         (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW,
         interleaved_k / Details::kElemsPerByteW);
-    GMemIterator scales_iterator(scales,
+    GMemIterator scales_iterator(scales,
         (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
         (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
-    GMemIterator zeros_iterator(zeros,
+    GMemIterator zeros_iterator(zeros,
         (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
         (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
 
@@ -92,11 +96,19 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
         TypeA vec_scale[CtaN], vec_zero[CtaN];
         TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK];
         uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW];
-#pragma unroll
-        for (int i = 0; i < CtaN; ++i)
+        if constexpr (scale_zero_ldg128)
         {
-            scales_iterator.load(vec_scale + i, iter, i);
-            zeros_iterator.load(vec_zero + i, iter, i);
+            scales_iterator.load(vec_scale, iter);
+            zeros_iterator.load(vec_zero, iter);
+        }
+        else
+        {
+#pragma unroll
+            for (int i = 0; i < CtaN; ++i)
+            {
+                scales_iterator.load(vec_scale + i, iter, i);
+                zeros_iterator.load(vec_zero + i, iter, i);
+            }
         }
         act_scale_iterator.load(vec_act_scale, iter);
 #pragma unroll
@@ -112,7 +124,8 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
         {
             act_iterator.load(tile_a, iter, i);
             apply_scale(tile_a, vec_act_scale);
-            mma(tile_acc + i * CtaN, tile_w_pack2, tile_a);
+            mma(
+                tile_acc + i * CtaN, tile_w_pack2, tile_a, vec_scale);
         }
     }
     epilogue(out, n, tile_acc, bias, alpha);
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h
index 8dfdf9f383..4e660f0d60 100644
--- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h
@@ -133,33 +133,37 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca
     {
         Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW,
             reinterpret_cast(w) + n * K);
-        Type2 vec_scale, vec_zero;
-        if constexpr (ApplyAlphaInAdvance)
+
+        if constexpr (EnableZero || ApplyAlphaInAdvance)
         {
-            // For W4A8, we assume scales/zero is always half data type, no matter activation dtype is bf16 or fp16
-            Type scales_ = static_cast(reinterpret_cast(scales)[n]) * alpha;
-            vec_scale = MathWrapper::to_vec2(scales_);
-            vec_zero = MathWrapper::to_vec2(static_cast(0.f));
-            if constexpr (EnableZero)
+            Type2 vec_scale, vec_zero;
+            if constexpr (ApplyAlphaInAdvance)
             {
-                vec_zero = MathWrapper::to_vec2(
-                    static_cast(reinterpret_cast(zeros)[n]) * alpha);
+                Type scales_ = static_cast(reinterpret_cast(scales)[n]) * alpha;
+                vec_scale = MathWrapper::to_vec2(scales_);
+                vec_zero = MathWrapper::to_vec2(static_cast(0.f));
+                if constexpr (EnableZero)
+                {
+                    vec_zero = MathWrapper::to_vec2(
+                        static_cast(reinterpret_cast(zeros)[n]) * alpha);
+                }
             }
-        }
-        else
-        {
-            vec_scale = MathWrapper::to_vec2(reinterpret_cast(scales)[n]);
-            vec_zero = MathWrapper::to_vec2(static_cast(0.f));
-            if constexpr (EnableZero)
+            else
             {
-                vec_zero = MathWrapper::to_vec2(reinterpret_cast(zeros)[n]);
+                vec_scale = MathWrapper::to_vec2(reinterpret_cast(scales)[n]);
+                vec_zero = MathWrapper::to_vec2(static_cast(0.f));
+                if constexpr (EnableZero)
+                {
+                    vec_zero = MathWrapper::to_vec2(reinterpret_cast(zeros)[n]);
+                }
             }
-        }
+
 #pragma unroll
-        for (int k = 0; k < VecK; ++k)
-        {
-            reinterpret_cast(w)[n * VecK + k] = MathWrapper::fma2(
-                reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero);
+            for (int k = 0; k < VecK; ++k)
+            {
+                reinterpret_cast(w)[n * VecK + k] = MathWrapper::fma2(
+                    reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero);
+            }
         }
     }
 }
@@ -177,8 +181,8 @@ __device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n)
     }
 }
 
-template
-__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
+template
+__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act, void* scale)
 {
     using Type = typename MathWrapper::Type;
     using Type2 = typename MathWrapper::Type2;
@@ -190,13 +194,30 @@ __device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
 #pragma unroll
         for (int n = 0; n < VecN; ++n)
         {
-#pragma unroll
-            for (int k = 0; k < K; ++k)
+            if constexpr (EnableZero || ApplyAlphaInAdvance)
             {
-                reinterpret_cast(acc)[m * VecN + n]
-                    = MathWrapper::fma2(reinterpret_cast(w_pack2)[n * K + k],
+#pragma unroll
+                for (int k = 0; k < K; ++k)
+                {
+                    reinterpret_cast(acc)[m * VecN + n] = MathWrapper::fma2(
+                        reinterpret_cast(w_pack2)[n * K + k],
                         MathWrapper::to_vec2(reinterpret_cast(act)[m * K + k]),
                         reinterpret_cast(acc)[m * VecN + n]);
+                }
+            }
+            else
+            {
+                Type2 local_acc{};
+#pragma unroll
+                for (int k = 0; k < K; ++k)
+                {
+                    local_acc = MathWrapper::fma2(
+                        reinterpret_cast(w_pack2)[n * K + k],
+                        MathWrapper::to_vec2(reinterpret_cast(act)[m * K + k]),
+                        local_acc);
+                }
+                reinterpret_cast(acc)[m * VecN + n] = MathWrapper::fma2(
+                    local_acc, reinterpret_cast(scale)[n], reinterpret_cast(acc)[m * VecN + n]);
             }
         }
     }
diff --git a/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp b/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp
index 3f22d594e2..ba95cfbe7d 100644
--- a/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp
+++ b/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp
@@ -165,13 +165,13 @@ struct cutlassTypeMapper
     }                                                                                                                  \
     };
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8Groupwise, "FP16Int8Groupwise", half, uint8_t, 8,
-    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8Groupwise, "BF16Int8Groupwise", __nv_bfloat16, uint8_t, 8,
-    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", half, cutlass::uint4b_t, 4,
-    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, "BF16Int4Groupwise", __nv_bfloat16, cutlass::uint4b_t,
-    4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8PerChannel, "FP16Int8PerChannel", half, uint8_t, 8,
     cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8PerChannel, "BF16Int8PerChannel", __nv_bfloat16, uint8_t, 8,
@@ -228,7 +228,7 @@ void exec_cutlass_kernel(
         runner.gemm(
             act, params.weight, params.scales, params.out, params.m, params.n, params.k, config, ws, ws_size, stream);
     }
-    else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
+    else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
     {
         runner.gemm(act, params.weight, params.scales, params.zeros, params.bias, params.out, params.m, params.n,
             params.k, params.groupsize, config, ws, ws_size, stream);
@@ -290,7 +290,7 @@ float run_cutlass_kernel(wo::Params& params, int warmup, int iter)
             msg << "\n (for"
                 << " m=" << params.m << ", n=" << params.n << ", k=" << params.k << ")"
                 << ", reason: \"" << e.what() << "\". Skipped\n";
-            std::cout << msg.str();
+            TLLM_LOG_TRACE(msg.str());
             cudaGetLastError(); // Reset the last cudaError to cudaSuccess.
             continue;
         }
@@ -332,7 +332,7 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
     {
         simple_assert(groupsize == 0);
     }
-    else if (cutlassTypeMapper::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
+    else if (cutlassTypeMapper::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
     {
         simple_assert(groupsize == 64 || groupsize == 128);
     }
@@ -379,7 +379,6 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
 
     if (groupsize != 0)
     {
-        p_zeros = d_zeros.data();
         p_bias = d_bias.data();
         p_act_scale = d_act_scale.data();
     }
@@ -402,9 +401,9 @@ TEST(Kernel, WeightOnly)
     int const arch = tensorrt_llm::common::getSMVersion();
     bool pass;
     int warmup = 10, iter = 30;
-    std::vector ms{2, 4, 6, 8, 10, 12, 14};
+    std::vector ms{1, 2, 4, 6, 8, 10, 12, 14};
     std::vector ns{4096};
-    std::vector ks{2048};
+    std::vector ks{4096};
     for (auto m : ms)
     {
         for (auto n : ns)
@@ -428,6 +427,10 @@ TEST(Kernel, WeightOnly)
 #if defined(ENABLE_BF16)
                 if (arch >= 80)
                 {
+                    pass = benchmark_and_verify(m, n, k, 0, warmup, iter);
+                    EXPECT_TRUE(pass);
+                    pass = benchmark_and_verify(m, n, k, 0, warmup, iter);
+                    EXPECT_TRUE(pass);
                     pass = benchmark_and_verify(m, n, k, 64, warmup, iter);
                     EXPECT_TRUE(pass);
                     pass = benchmark_and_verify(m, n, k, 128, warmup, iter);
@@ -436,10 +439,6 @@ TEST(Kernel, WeightOnly)
                     EXPECT_TRUE(pass);
                     pass = benchmark_and_verify(m, n, k, 128, warmup, iter);
                     EXPECT_TRUE(pass);
-                    pass = benchmark_and_verify(m, n, k, 0, warmup, iter);
-                    EXPECT_TRUE(pass);
-                    pass = benchmark_and_verify(m, n, k, 0, warmup, iter);
-                    EXPECT_TRUE(pass);
                 }
 #endif
             }
diff --git a/tests/unittest/trt/quantization/test_quant_layer.py b/tests/unittest/trt/quantization/test_quant_layer.py
index 2f6c67da35..20cfad1a01 100644
--- a/tests/unittest/trt/quantization/test_quant_layer.py
+++ b/tests/unittest/trt/quantization/test_quant_layer.py
@@ -644,7 +644,7 @@ class TestSmoothQuant(unittest.TestCase):
         k = 4096
 
         # Init operands for multiplication in int32
-        mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
+        mat1 = _utils.woq_gen_weights(m, k, dtype)
         weight = _utils.woq_gen_weights(k, n, dtype)
 
         ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(
diff --git a/tests/unittest/trt/quantization/test_weight_only_quant_matmul.py b/tests/unittest/trt/quantization/test_weight_only_quant_matmul.py
index 0f779c0613..1dae1b405d 100644
--- a/tests/unittest/trt/quantization/test_weight_only_quant_matmul.py
+++ b/tests/unittest/trt/quantization/test_weight_only_quant_matmul.py
@@ -87,7 +87,7 @@ class TestWeightOnlyQuantMatmul(unittest.TestCase):
 
     def _woq_matmul(self, m, n, k, dtype, wTypeId, use_plugin=True):
         # Init operands for multiplication in int32
-        mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
+        mat1 = _utils.woq_gen_weights(m, k, dtype)
         weight = _utils.woq_gen_weights(k, n, dtype)
 
         ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(