Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[feat] Optimizations on weight-only batched gemv kernel (#5420)
Signed-off-by: Cheng Hang <chang@nvidia.com>
Parent: b4dab23e7b
Commit: 64db7d27f6
@@ -63,6 +63,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
         = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize
         + ((tid * StepK) % Details::LayoutDetails::kTileSize);
 
+    bool constexpr scale_zero_ldg128 = Details::kInterleave == 1 && CtaN == 8;
+    using AccessTypeScaleZero = std::conditional_t<scale_zero_ldg128, AccessTypeA, TypeA>;
+
     GMemIterator<Mandatory, AccessTypeA, CtaM, Details::kAccessNumA, TypeA> act_iterator(
         act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k);
     GMemIterator<EnableActScale, AccessTypeA, 1, Details::kAccessNumA, TypeA> act_scale_iterator(
@@ -70,10 +74,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
     GMemIterator<Mandatory, AccessTypeW, CtaN, Details::kAccessNumW, uint8_t> weight_iterator(weight,
         (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW,
         interleaved_k / Details::kElemsPerByteW);
-    GMemIterator<Mandatory, TypeA, CtaN, 1, TypeA> scales_iterator(scales,
+    GMemIterator<Mandatory, AccessTypeScaleZero, CtaN, 1, TypeA> scales_iterator(scales,
         (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
         (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
-    GMemIterator<EnableZero, TypeA, CtaN, 1, TypeA> zeros_iterator(zeros,
+    GMemIterator<EnableZero, AccessTypeScaleZero, CtaN, 1, TypeA> zeros_iterator(zeros,
         (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
         (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
 
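The `scale_zero_ldg128` flag and `AccessTypeScaleZero` alias introduced in the previous hunk feed these iterator changes: when `Details::kInterleave == 1` the CtaN per-column scales (and zeros) are contiguous, and with `CtaN == 8` eight half-precision values occupy exactly 16 bytes, so one 128-bit load can replace eight scalar loads (the next hunk branches between the two paths). A host-side sketch of the `std::conditional_t` selection; `HalfBits`, `Vec128`, and `load_scales` are illustrative stand-ins, not the kernel's actual types or iterators:

// Host-side sketch (illustration only; not from the commit).
// Picks a 16-byte vector access type when 8 contiguous half-precision scales
// can be fetched in one transaction, otherwise falls back to scalar loads.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>

using HalfBits = std::uint16_t;                // stand-in for the kernel's TypeA (half)
struct alignas(16) Vec128 { HalfBits x[8]; };  // stand-in for a 128-bit access type

template <int CtaN, int Interleave>
void load_scales(HalfBits (&dst)[CtaN], HalfBits const* src)
{
    constexpr bool scale_zero_ldg128 = (Interleave == 1 && CtaN == 8);
    using AccessType = std::conditional_t<scale_zero_ldg128, Vec128, HalfBits>;

    if constexpr (scale_zero_ldg128)
    {
        // One 16-byte copy stands in for a single vectorized global load (LDG.128).
        static_assert(sizeof(AccessType) == sizeof(dst), "8 halves == 16 bytes");
        std::memcpy(dst, src, sizeof(AccessType));
    }
    else
    {
        // Interleaved layout: one scalar load per output column.
        for (int i = 0; i < CtaN; ++i)
            dst[i] = src[i * Interleave];
    }
}

int main()
{
    HalfBits scales[16];
    for (int i = 0; i < 16; ++i)
        scales[i] = static_cast<HalfBits>(i);

    HalfBits a[8], b[4];
    load_scales<8, 1>(a, scales); // vector path
    load_scales<4, 2>(b, scales); // scalar path
    std::printf("%u %u %u %u\n", (unsigned) a[7], (unsigned) b[0], (unsigned) b[1], (unsigned) b[3]);
    return 0;
}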
@@ -92,11 +96,19 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
         TypeA vec_scale[CtaN], vec_zero[CtaN];
         TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK];
         uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW];
-#pragma unroll
-        for (int i = 0; i < CtaN; ++i)
+        if constexpr (scale_zero_ldg128)
         {
-            scales_iterator.load(vec_scale + i, iter, i);
-            zeros_iterator.load(vec_zero + i, iter, i);
+            scales_iterator.load(vec_scale, iter);
+            zeros_iterator.load(vec_zero, iter);
+        }
+        else
+        {
+#pragma unroll
+            for (int i = 0; i < CtaN; ++i)
+            {
+                scales_iterator.load(vec_scale + i, iter, i);
+                zeros_iterator.load(vec_zero + i, iter, i);
+            }
         }
         act_scale_iterator.load(vec_act_scale, iter);
 #pragma unroll
@@ -112,7 +124,8 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
         {
             act_iterator.load(tile_a, iter, i);
             apply_scale<Details, 1, StepK, EnableActScale>(tile_a, vec_act_scale);
-            mma<Details, 1, CtaN, StepK>(tile_acc + i * CtaN, tile_w_pack2, tile_a);
+            mma<Details, 1, CtaN, StepK, EnableZero, ApplyAlphaInAdvance>(
+                tile_acc + i * CtaN, tile_w_pack2, tile_a, vec_scale);
         }
     }
     epilogue<Details, CtaM, CtaN, Threads, EnableBias, ApplyAlphaInAdvance>(out, n, tile_acc, bias, alpha);
 
@@ -133,33 +133,37 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca
     {
         Converter::convert<K>(reinterpret_cast<uint8_t*>(quantized_w) + n * K / Details::kElemsPerByteW,
             reinterpret_cast<Type*>(w) + n * K);
-        Type2 vec_scale, vec_zero;
-        if constexpr (ApplyAlphaInAdvance)
+        if constexpr (EnableZero || ApplyAlphaInAdvance)
         {
-            // For W4A8, we assume scales/zero is always half data type, no matter activation dtype is bf16 or fp16
-            Type scales_ = static_cast<float>(reinterpret_cast<half*>(scales)[n]) * alpha;
-            vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(scales_);
-            vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
-            if constexpr (EnableZero)
+            Type2 vec_scale, vec_zero;
+            if constexpr (ApplyAlphaInAdvance)
             {
-                vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(
-                    static_cast<float>(reinterpret_cast<half*>(zeros)[n]) * alpha);
+                // For W4A8, we assume scales/zero is always half data type, no matter activation dtype is bf16 or fp16
+                Type scales_ = static_cast<float>(reinterpret_cast<half*>(scales)[n]) * alpha;
+                vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(scales_);
+                vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
+                if constexpr (EnableZero)
+                {
+                    vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(
+                        static_cast<float>(reinterpret_cast<half*>(zeros)[n]) * alpha);
+                }
             }
-        }
-        else
-        {
-            vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(scales)[n]);
-            vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
-            if constexpr (EnableZero)
+            else
             {
-                vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(zeros)[n]);
+                vec_scale = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(scales)[n]);
+                vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(static_cast<Type>(0.f));
+                if constexpr (EnableZero)
+                {
+                    vec_zero = MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(zeros)[n]);
+                }
             }
-        }
 #pragma unroll
-        for (int k = 0; k < VecK; ++k)
-        {
-            reinterpret_cast<Type2*>(w)[n * VecK + k] = MathWrapper<typename Details::TypeDetailsA>::fma2(
-                reinterpret_cast<Type2*>(w)[n * VecK + k], vec_scale, vec_zero);
+            for (int k = 0; k < VecK; ++k)
+            {
+                reinterpret_cast<Type2*>(w)[n * VecK + k] = MathWrapper<typename Details::TypeDetailsA>::fma2(
+                    reinterpret_cast<Type2*>(w)[n * VecK + k], vec_scale, vec_zero);
+            }
         }
     }
 }
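With this restructuring, the per-element `fma2(w, vec_scale, vec_zero)` sweep over the converted weights only exists when a zero point or a pre-applied alpha actually has to be folded in; otherwise the whole block is discarded at compile time and the raw converted weights go straight to `mma`, which applies the scale once per accumulator (see the hunks below). A reduced host-side sketch of that compile-time guard; `dequantize_tile` and the plain-float math are stand-ins for the kernel's packed half2 helpers:

// Host-side sketch (illustration only; not from the commit).
// Shows compile-time removal of the per-element scale/zero pass.
#include <cstdio>

template <bool EnableZero, bool ApplyAlphaInAdvance, int K>
void dequantize_tile(float (&w)[K], float scale, float zero)
{
    // Integer-to-float conversion would happen before this point; w holds raw values.
    if constexpr (EnableZero || ApplyAlphaInAdvance)
    {
        // Only this configuration pays for the extra K fused multiply-adds.
        for (int k = 0; k < K; ++k)
            w[k] = w[k] * scale + zero;
    }
    // Otherwise w stays unscaled; the mma step applies the scale to the accumulated sum.
}

int main()
{
    float tile_a[4] = {1, 2, 3, 4};
    float tile_b[4] = {1, 2, 3, 4};
    dequantize_tile<true, false, 4>(tile_a, 0.5f, 1.0f);  // scaled and shifted in place
    dequantize_tile<false, false, 4>(tile_b, 0.5f, 1.0f); // pass compiled out, w untouched
    std::printf("%g %g\n", tile_a[0], tile_b[0]);         // prints: 1.5 1
    return 0;
}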
@@ -177,8 +181,8 @@ __device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n)
     }
 }
 
-template <typename Details, int M, int N, int K>
-__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
+template <typename Details, int M, int N, int K, bool EnableZero, bool ApplyAlphaInAdvance>
+__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act, void* scale)
 {
     using Type = typename MathWrapper<typename Details::TypeDetailsA>::Type;
     using Type2 = typename MathWrapper<typename Details::TypeDetailsA>::Type2;
@@ -190,13 +194,30 @@ __device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
 #pragma unroll
         for (int n = 0; n < VecN; ++n)
         {
-#pragma unroll
-            for (int k = 0; k < K; ++k)
-            {
-                reinterpret_cast<Type2*>(acc)[m * VecN + n]
-                    = MathWrapper<typename Details::TypeDetailsA>::fma2(reinterpret_cast<Type2*>(w_pack2)[n * K + k],
-                        MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]),
-                        reinterpret_cast<Type2*>(acc)[m * VecN + n]);
-            }
+            if constexpr (EnableZero || ApplyAlphaInAdvance)
+            {
+#pragma unroll
+                for (int k = 0; k < K; ++k)
+                {
+                    reinterpret_cast<Type2*>(acc)[m * VecN + n] = MathWrapper<typename Details::TypeDetailsA>::fma2(
+                        reinterpret_cast<Type2*>(w_pack2)[n * K + k],
+                        MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]),
+                        reinterpret_cast<Type2*>(acc)[m * VecN + n]);
+                }
+            }
+            else
+            {
+                Type2 local_acc{};
+#pragma unroll
+                for (int k = 0; k < K; ++k)
+                {
+                    local_acc = MathWrapper<typename Details::TypeDetailsA>::fma2(
+                        reinterpret_cast<Type2*>(w_pack2)[n * K + k],
+                        MathWrapper<typename Details::TypeDetailsA>::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]),
+                        local_acc);
+                }
+                reinterpret_cast<Type2*>(acc)[m * VecN + n] = MathWrapper<typename Details::TypeDetailsA>::fma2(
+                    local_acc, reinterpret_cast<Type2*>(scale)[n], reinterpret_cast<Type2*>(acc)[m * VecN + n]);
+            }
         }
     }
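The new `else` branch defers the dequantization scale: instead of multiplying every weight by the scale before the dot product, it accumulates raw weight-activation products in `local_acc` and multiplies by `scale[n]` once per (m, n) accumulator, saving roughly K multiplies per output element. The identity being exploited is sum_k (s * w_k) * a_k = s * sum_k (w_k * a_k); in half precision the two orderings can round slightly differently, presumably within the tests' tolerance. A small float check of the identity (stand-in arithmetic, not the kernel's half2 path):

// Host-side sketch (illustration only; not from the commit).
// Verifies that applying the per-column scale after accumulation matches
// scaling every dequantized weight before the dot product.
#include <cmath>
#include <cstdio>

int main()
{
    constexpr int K = 16;
    float w[K], a[K];
    float const scale = 0.02f; // per-column dequantization scale
    for (int k = 0; k < K; ++k)
    {
        w[k] = static_cast<float>((k * 7) % 15 - 7);  // raw integer-like weight values
        a[k] = 0.25f * static_cast<float>(k) - 1.0f;  // activation values
    }

    float acc_scaled_early = 0.0f; // old path: scale folded into dequantization
    float local_acc = 0.0f;        // new path: accumulate unscaled products
    for (int k = 0; k < K; ++k)
    {
        acc_scaled_early = std::fma(w[k] * scale, a[k], acc_scaled_early);
        local_acc = std::fma(w[k], a[k], local_acc);
    }
    float const acc_scaled_late = local_acc * scale; // one multiply per accumulator

    std::printf("early=%f late=%f diff=%g\n", acc_scaled_early, acc_scaled_late,
        std::fabs(acc_scaled_early - acc_scaled_late));
    return 0;
}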
@@ -165,13 +165,13 @@ struct cutlassTypeMapper
     } \
     };
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8Groupwise, "FP16Int8Groupwise", half, uint8_t, 8,
-    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8Groupwise, "BF16Int8Groupwise", __nv_bfloat16, uint8_t, 8,
-    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", half, cutlass::uint4b_t, 4,
-    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, "BF16Int4Groupwise", __nv_bfloat16, cutlass::uint4b_t,
-    4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
+    4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8PerChannel, "FP16Int8PerChannel", half, uint8_t, 8,
     cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY);
 CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8PerChannel, "BF16Int8PerChannel", __nv_bfloat16, uint8_t, 8,
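Only the quant-op argument of these registrations changes: the groupwise reference path in the test now exercises FINEGRAINED_SCALE_ONLY rather than FINEGRAINED_SCALE_AND_ZEROS. For readers unfamiliar with the pattern, a registry macro of this kind typically expands to an explicit specialization of the mapper trait; the sketch below is a simplified stand-in (`TypeMapper`, `REGISTER_KERNEL_TYPE`, and the fields are illustrative, not the test's actual definitions):

// Host-side sketch (illustration only; not from the commit).
// A macro-driven trait registry in the style of CUTLASS_TYPE_MAPPER_REGISTRY.
#include <cstdio>

enum class KernelType { FP16Int8Groupwise, FP16Int4Groupwise };
enum class QuantOp { FINEGRAINED_SCALE_ONLY, FINEGRAINED_SCALE_AND_ZEROS };

template <KernelType KT>
struct TypeMapper; // primary template left undefined: unregistered kernel types fail to compile

#define REGISTER_KERNEL_TYPE(KT, Name, WeightBits, Op)  \
    template <>                                         \
    struct TypeMapper<KT>                               \
    {                                                   \
        static constexpr char const* name = Name;       \
        static constexpr int kWeightBits = WeightBits;  \
        static constexpr QuantOp kQuantOp = Op;         \
    };

REGISTER_KERNEL_TYPE(KernelType::FP16Int8Groupwise, "FP16Int8Groupwise", 8, QuantOp::FINEGRAINED_SCALE_ONLY)
REGISTER_KERNEL_TYPE(KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", 4, QuantOp::FINEGRAINED_SCALE_ONLY)

int main()
{
    using M = TypeMapper<KernelType::FP16Int8Groupwise>;
    std::printf("%s: %d-bit weights, scale-only=%d\n", M::name, M::kWeightBits,
        static_cast<int>(M::kQuantOp == QuantOp::FINEGRAINED_SCALE_ONLY));
    return 0;
}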
@@ -228,7 +228,7 @@ void exec_cutlass_kernel(
         runner.gemm(
             act, params.weight, params.scales, params.out, params.m, params.n, params.k, config, ws, ws_size, stream);
     }
-    else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
+    else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
     {
         runner.gemm(act, params.weight, params.scales, params.zeros, params.bias, params.out, params.m, params.n,
             params.k, params.groupsize, config, ws, ws_size, stream);
@@ -290,7 +290,7 @@ float run_cutlass_kernel(wo::Params& params, int warmup, int iter)
             msg << "\n (for"
                 << " m=" << params.m << ", n=" << params.n << ", k=" << params.k << ")"
                 << ", reason: \"" << e.what() << "\". Skipped\n";
-            std::cout << msg.str();
+            TLLM_LOG_TRACE(msg.str());
             cudaGetLastError(); // Reset the last cudaError to cudaSuccess.
             continue;
         }
@@ -332,7 +332,7 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
     {
         simple_assert(groupsize == 0);
     }
-    else if (cutlassTypeMapper<KT>::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
+    else if (cutlassTypeMapper<KT>::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
     {
         simple_assert(groupsize == 64 || groupsize == 128);
     }
@@ -379,7 +379,6 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
 
     if (groupsize != 0)
     {
-        p_zeros = d_zeros.data();
         p_bias = d_bias.data();
         p_act_scale = d_act_scale.data();
     }
@@ -402,9 +401,9 @@ TEST(Kernel, WeightOnly)
     int const arch = tensorrt_llm::common::getSMVersion();
     bool pass;
     int warmup = 10, iter = 30;
-    std::vector<int> ms{2, 4, 6, 8, 10, 12, 14};
+    std::vector<int> ms{1, 2, 4, 6, 8, 10, 12, 14};
     std::vector<int> ns{4096};
-    std::vector<int> ks{2048};
+    std::vector<int> ks{4096};
     for (auto m : ms)
     {
         for (auto n : ns)
@@ -428,6 +427,10 @@ TEST(Kernel, WeightOnly)
 #if defined(ENABLE_BF16)
         if (arch >= 80)
         {
+            pass = benchmark_and_verify<wo::KernelType::BF16Int8PerChannel>(m, n, k, 0, warmup, iter);
+            EXPECT_TRUE(pass);
+            pass = benchmark_and_verify<wo::KernelType::BF16Int4PerChannel>(m, n, k, 0, warmup, iter);
+            EXPECT_TRUE(pass);
             pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 64, warmup, iter);
             EXPECT_TRUE(pass);
             pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 128, warmup, iter);
@@ -436,10 +439,6 @@ TEST(Kernel, WeightOnly)
             EXPECT_TRUE(pass);
             pass = benchmark_and_verify<wo::KernelType::BF16Int4Groupwise>(m, n, k, 128, warmup, iter);
             EXPECT_TRUE(pass);
-            pass = benchmark_and_verify<wo::KernelType::BF16Int8PerChannel>(m, n, k, 0, warmup, iter);
-            EXPECT_TRUE(pass);
-            pass = benchmark_and_verify<wo::KernelType::BF16Int4PerChannel>(m, n, k, 0, warmup, iter);
-            EXPECT_TRUE(pass);
         }
 #endif
     }
 
@@ -644,7 +644,7 @@ class TestSmoothQuant(unittest.TestCase):
         k = 4096
 
         # Init operands for multiplication in int32
-        mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
+        mat1 = _utils.woq_gen_weights(m, k, dtype)
         weight = _utils.woq_gen_weights(k, n, dtype)
 
         ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(
@@ -87,7 +87,7 @@ class TestWeightOnlyQuantMatmul(unittest.TestCase):
 
     def _woq_matmul(self, m, n, k, dtype, wTypeId, use_plugin=True):
         # Init operands for multiplication in int32
-        mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
+        mat1 = _utils.woq_gen_weights(m, k, dtype)
         weight = _utils.woq_gen_weights(k, n, dtype)
 
         ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(