TensorRT-LLMs/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp
Kaiyu Xie 6755a3f077
Update TensorRT-LLM (#422)
* Update TensorRT-LLM

---------

Co-authored-by: Tltin <TltinDeng01@gmail.com>
Co-authored-by: zhaohb <zhaohbcloud@126.com>
Co-authored-by: Bradley Heilbrun <brad@repl.it>
Co-authored-by: nqbao11 <nqbao11.01@gmail.com>
Co-authored-by: Nikhil Varghese <nikhil@bot-it.ai>
2023-11-18 00:05:54 +08:00

442 lines
14 KiB
C++

#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include "cutlass/numeric_types.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <functional>
#include <iostream>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
using tensorrt_llm::kernels::WeightOnlyParams;
using tensorrt_llm::kernels::WeightOnlyType;
using tensorrt_llm::kernels::WeightOnlyQuantType;
using tensorrt_llm::kernels::WeightOnlyActivationType;
using tensorrt_llm::kernels::WeightOnlyActivationFunctionType;
/// Maps a WeightOnlyActivationType value to the activation element types consumed by
/// each implementation under test:
///   CudaKernelAType    - element type used with the batched-GEMV CUDA kernel.
///   CutlassKernelAType - element type used with the CUTLASS fpA_intB runner.
/// The primary template is declared but never defined, so any activation type without a
/// specialization below fails to compile.
template <WeightOnlyActivationType T>
struct AType;

// FP16 activations: both implementations operate on `half`.
template <>
struct AType<WeightOnlyActivationType::FP16>
{
    typedef half CudaKernelAType;
    typedef half CutlassKernelAType;
};

#if defined(ENABLE_BF16)
// BF16 activations: only available when the build defines ENABLE_BF16.
template <>
struct AType<WeightOnlyActivationType::BF16>
{
    typedef __nv_bfloat16 CudaKernelAType;
    typedef __nv_bfloat16 CutlassKernelAType;
};
#endif
/// Maps a WeightOnlyQuantType value to the weight storage types used by each
/// implementation under test, plus the number of quantized values packed per byte.
/// The primary template is declared but never defined; only the specializations below exist.
template <WeightOnlyQuantType T>
struct BType;

// 4-bit weights: the CUDA kernel receives raw bytes, CUTLASS receives cutlass::uint4b_t;
// two quantized values share one byte.
template <>
struct BType<WeightOnlyQuantType::Int4b>
{
    typedef uint8_t CudaKernelBType;
    typedef cutlass::uint4b_t CutlassKernelBType;
    static constexpr int elemsPerByte = 2;
};

// 8-bit weights: one quantized value per byte, stored as uint8_t for both implementations.
template <>
struct BType<WeightOnlyQuantType::Int8b>
{
    typedef uint8_t CudaKernelBType;
    typedef uint8_t CutlassKernelBType;
    static constexpr int elemsPerByte = 1;
};
// Tag types passed as the KernelFlag template argument of the benchmark_* functions to pick
// the implementation being timed. They are only used as types and are never defined.
struct CutlassKernel;
struct CudaKernel;

/// Throws std::runtime_error when `flag` is false.
///
/// Unlike assert(), this check is a plain runtime branch, so it stays active in NDEBUG builds,
/// and a violation surfaces as an exception instead of aborting the process.
/// @param flag condition that must hold
/// @param msg  text stored in the thrown exception; defaults to "assert failed"
void simple_assert(bool flag, char const* msg = "assert failed")
{
    if (!flag)
    {
        throw std::runtime_error(msg);
    }
}
template <typename KernelFlag, WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n,
int k, int group_size, int warmup, int iter)
{
simple_assert(zeros == nullptr && bias == nullptr && group_size == 0);
cudaStream_t s;
cudaStreamCreate(&s);
cudaEvent_t begin, end;
cudaEventCreate(&begin);
cudaEventCreate(&end);
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
{
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag};
for (int i = 0; i < warmup; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
}
else if (std::is_same_v<KernelFlag, CutlassKernel>)
{
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<typename AType<AFlag>::CutlassKernelAType,
typename BType<BFlag>::CutlassKernelBType, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
gemm;
auto configs = gemm.getConfigs();
int ws_bytes = gemm.getWorkspaceSize(m, n, k);
char* ws_ptr = nullptr;
if (ws_bytes)
cudaMalloc(&ws_ptr, ws_bytes);
float fast_time = 1e8;
auto best_config = configs[0];
for (auto& config : configs)
{
for (int i = 0; i < 2; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < 5; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
if (time < fast_time)
{
fast_time = time;
best_config = config;
}
}
for (int i = 0; i < warmup; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s);
}
if (ws_ptr)
cudaFree(ws_ptr);
}
cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
cudaEventDestroy(begin);
cudaEventDestroy(end);
cudaStreamDestroy(s);
return time / iter;
}
template <typename KernelFlag, WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n,
int k, int group_size, int warmup, int iter)
{
simple_assert(zeros && bias && (group_size == 64 || group_size == 128));
cudaStream_t s;
cudaStreamCreate(&s);
cudaEvent_t begin, end;
cudaEventCreate(&begin);
cudaEventCreate(&end);
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
{
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag};
for (int i = 0; i < warmup; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
}
else if (std::is_same_v<KernelFlag, CutlassKernel>)
{
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<typename AType<AFlag>::CutlassKernelAType,
typename BType<BFlag>::CutlassKernelBType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
gemm;
auto configs = gemm.getConfigs();
int ws_bytes = gemm.getWorkspaceSize(m, n, k);
char* ws_ptr = nullptr;
if (ws_bytes)
cudaMalloc(&ws_ptr, ws_bytes);
float fast_time = 1e8;
auto best_config = configs[0];
for (auto& config : configs)
{
for (int i = 0; i < 2; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < 5; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
if (time < fast_time)
{
fast_time = time;
best_config = config;
}
}
for (int i = 0; i < warmup; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s);
}
if (ws_ptr)
cudaFree(ws_ptr);
}
cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
cudaEventDestroy(begin);
cudaEventDestroy(end);
cudaStreamDestroy(s);
return time / iter;
}
struct CudaBuffer
{
void* _data;
int _size;
CudaBuffer(int size_in_bytes)
: _size(size_in_bytes)
{
cudaMalloc(&_data, _size);
}
template <typename T = void>
T* data()
{
return reinterpret_cast<T*>(_data);
}
void copy_to(void* dst)
{
cudaMemcpy(dst, _data, _size, cudaMemcpyDeviceToHost);
}
void copy_from(void* src)
{
cudaMemcpy(_data, src, _size, cudaMemcpyHostToDevice);
}
~CudaBuffer()
{
cudaFree(_data);
}
};
/// Compares two host arrays of `size` elements of type T and prints a summary line.
///
/// Element differences strictly greater than 1e-7 are counted as mismatches and feed the
/// reported max/avg difference. The arrays are considered equal when the largest difference
/// is at most `scale * max|b|`, widened by 1.5x (2x for bfloat16, which has fewer mantissa
/// bits than float16). Any NaN or Inf in either array is a failure: NaN compares false
/// against every threshold and would otherwise be skipped.
///
/// @param _pa   host pointer to the array under test
/// @param _pb   host pointer to the reference array; its magnitude sets the tolerance
/// @param size  number of elements in each array
/// @param scale relative tolerance factor
/// @return true when all values are finite and the max difference is within tolerance
template <typename T>
bool compare(void* _pa, void* _pb, int size, float scale)
{
    auto const* pa = static_cast<T const*>(_pa);
    auto const* pb = static_cast<T const*>(_pb);
    float max_diff = 0.f, tot_diff = 0.f;
    float max_abs_ref = 0.f;
    int diff_cnt = 0;
    int non_finite_cnt = 0;
    float const threshold = 1e-7f;
    for (int i = 0; i < size; ++i)
    {
        float const va = static_cast<float>(pa[i]);
        float const vb = static_cast<float>(pb[i]);
        if (!std::isfinite(va) || !std::isfinite(vb))
        {
            ++non_finite_cnt;
            continue;
        }
        // Use the magnitude: a signed maximum is ~0 for an all-negative reference and would
        // collapse the tolerance to zero.
        max_abs_ref = std::max(max_abs_ref, std::abs(vb));
        float const diff = std::abs(va - vb);
        if (diff > threshold)
        {
            max_diff = std::max(max_diff, diff);
            tot_diff += diff;
            ++diff_cnt;
        }
    }
    float diff_thres = max_abs_ref * scale;
#if defined(ENABLE_BF16)
    if constexpr (std::is_same_v<T, __nv_bfloat16>)
    {
        // bfloat16 has fewer mantissa digits than float16, so the cumulative error will be larger.
        diff_thres *= 2.f;
    }
    else
#endif
    {
        diff_thres *= 1.5f;
    }
    // Avoid 0/0 when nothing differs.
    float const avg_diff = diff_cnt > 0 ? tot_diff / diff_cnt : 0.f;
    printf("max diff %f (diff threshold %f), avg diff %f, diff cnt %d/%d\n", max_diff, diff_thres, avg_diff,
        diff_cnt, size);
    if (non_finite_cnt > 0)
    {
        printf("non-finite values %d/%d\n", non_finite_cnt, size);
    }
    return non_finite_cnt == 0 && max_diff <= diff_thres;
}
/// Returns the pseudo-random engine shared by every random_fill call. It is created once with
/// the standard default seed, so a given sequence of fills yields the same inputs on every run
/// and a failing configuration can be reproduced.
static std::mt19937& random_fill_engine()
{
    static std::mt19937 engine(std::mt19937::default_seed);
    return engine;
}

/// Fills `vec` with values drawn uniformly from [minv, maxv) (computed in float) and
/// converted to T1.
/// @param vec  destination; every element is overwritten
/// @param minv lower bound, converted to float; must not exceed maxv
/// @param maxv upper bound, converted to float
template <typename T1, typename T2>
void random_fill(std::vector<T1>& vec, T2 minv, T2 maxv)
{
    std::uniform_real_distribution<float> dist(static_cast<float>(minv), static_cast<float>(maxv));
    std::mt19937& engine = random_fill_engine();
    for (auto& v : vec)
    {
        v = static_cast<T1>(dist(engine));
    }
}
/// Runs one weight-only GEMM configuration through both the CUDA batched-GEMV kernel and the
/// CUTLASS fpA_intB runner on the same random inputs, prints both timings, and checks that the
/// two outputs agree.
///
/// @param m, n, k     GEMM problem size (activation buffer holds m*k elements, output m*n)
/// @param group_size  0 selects the per-channel path; any other value selects the group-wise
///                    path (benchmark_groupwise itself only accepts 64 or 128)
/// @param warmup      untimed launches per implementation
/// @param iter        timed launches per implementation
/// @return true when compare() accepts the CUDA output against the CUTLASS output
template <WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
bool benchmark(int m, int n, int k, int group_size, int warmup, int iter)
{
    bool const is_groupwise = group_size != 0;
    printf("benchmark mnk (%d, %d, %d) ", m, n, k);
    if (AFlag == WeightOnlyActivationType::FP16)
    {
        printf("FP16 Activation ");
    }
    else
    {
        printf("BF16 Activation ");
    }
    if (BFlag == WeightOnlyQuantType::Int8b)
    {
        printf("Int8b ");
    }
    else
    {
        printf("Int4b ");
    }
    if (is_groupwise)
    {
        printf("GroupWise%d Weight Only\n", group_size);
    }
    else
    {
        printf("PerChannel Weight Only\n");
    }

    using AT = typename AType<AFlag>::CudaKernelAType;
    constexpr int elem_per_byte = BType<BFlag>::elemsPerByte;

    // Sizes are computed in size_t so larger shapes cannot overflow int arithmetic.
    size_t const act_elems = static_cast<size_t>(m) * k;
    size_t const weight_bytes = static_cast<size_t>(k) * n / elem_per_byte;
    size_t const nk_elems = static_cast<size_t>(n) * k;
    size_t const out_elems = static_cast<size_t>(m) * n;
    size_t const out_bytes = out_elems * sizeof(AT);

    CudaBuffer d_act(act_elems * sizeof(AT));
    CudaBuffer d_weight(weight_bytes);
    // Scales and zeros are allocated generously at n * k elements.
    CudaBuffer d_scales(nk_elems * sizeof(AT));
    CudaBuffer d_zeros(nk_elems * sizeof(AT));
    CudaBuffer d_bias(static_cast<size_t>(n) * sizeof(AT));
    CudaBuffer d_out(out_bytes);

    std::vector<AT> h_act(act_elems);
    // Exactly the packed byte count copied into d_weight.
    std::vector<uint8_t> h_weight(weight_bytes);
    std::vector<AT> h_scales(nk_elems), h_zeros(nk_elems), h_bias(n);
    std::vector<AT> h_out1(out_elems), h_out2(out_elems);
    random_fill(h_act, -1.f, 1.f);
    random_fill(h_scales, -1.f, 1.f);
    for (uint8_t& v : h_weight)
    {
        v = rand() % 256;
    }
    d_act.copy_from(h_act.data());
    d_weight.copy_from(h_weight.data());
    d_scales.copy_from(h_scales.data());
    d_zeros.copy_from(h_zeros.data());
    d_bias.copy_from(h_bias.data());

    // Zeros/bias are passed only on the group-wise path, using the same predicate that selects it.
    void* p_zeros = is_groupwise ? d_zeros.data() : nullptr;
    void* p_bias = is_groupwise ? d_bias.data() : nullptr;

    std::function<decltype(benchmark_perchannel<CudaKernel, AFlag, BFlag>)> benchmark_func_cuda
        = benchmark_perchannel<CudaKernel, AFlag, BFlag>;
    std::function<decltype(benchmark_perchannel<CutlassKernel, AFlag, BFlag>)> benchmark_func_cutlass
        = benchmark_perchannel<CutlassKernel, AFlag, BFlag>;
    if (is_groupwise)
    {
        benchmark_func_cuda = benchmark_groupwise<CudaKernel, AFlag, BFlag>;
        benchmark_func_cutlass = benchmark_groupwise<CutlassKernel, AFlag, BFlag>;
    }

    // Both implementations write into the same d_out. Zero it before each run so a kernel that
    // silently writes nothing cannot leave the previous kernel's result in place and make the
    // comparison trivially pass.
    cudaMemset(d_out.data(), 0, out_bytes);
    float const time1 = benchmark_func_cuda(d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias,
        d_out.data(), m, n, k, group_size, warmup, iter);
    d_out.copy_to(h_out1.data());
    cudaMemset(d_out.data(), 0, out_bytes);
    float const time2 = benchmark_func_cutlass(d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias,
        d_out.data(), m, n, k, group_size, warmup, iter);
    d_out.copy_to(h_out2.data());

    // Relative tolerance: 1/128 for 8-bit weights, 1/8 for 4-bit weights.
    float const quant_scale = 1.f / (1 << (8 / elem_per_byte - 1));
    bool const pass = compare<AT>(h_out1.data(), h_out2.data(), m * n, quant_scale);
    printf(
        "cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n", time1, time2, time2 / time1);
    return pass;
}
// Sweeps every (m, n, k, group_size) combination below for each enabled activation/weight type
// pair and expects the CUDA batched-GEMV kernel and the CUTLASS fpA_intB kernel to agree.
// group_size 0 selects per-channel scales; 64 and 128 select group-wise scales.
TEST(Kernel, WeightOnly)
{
    int const warmup = 10, iter = 30;
    std::vector<int> const ms{1, 2, 4};
    std::vector<int> const ns{512, 1024, 2048, 4096};
    std::vector<int> const ks{512, 1024, 2048, 4096};
    std::vector<int> const gss{0, 64, 128};
    for (int m : ms)
    {
        for (int n : ns)
        {
            for (int k : ks)
            {
                for (int gs : gss)
                {
                    // Attach the shape to every failure below so the failing case can be rerun directly.
                    SCOPED_TRACE(::testing::Message() << "m=" << m << " n=" << n << " k=" << k << " group_size=" << gs);
                    EXPECT_TRUE((benchmark<WeightOnlyActivationType::FP16, WeightOnlyQuantType::Int8b>(
                        m, n, k, gs, warmup, iter)))
                        << "FP16 activation, Int8b weights";
                    EXPECT_TRUE((benchmark<WeightOnlyActivationType::FP16, WeightOnlyQuantType::Int4b>(
                        m, n, k, gs, warmup, iter)))
                        << "FP16 activation, Int4b weights";
#if defined(ENABLE_BF16)
                    EXPECT_TRUE((benchmark<WeightOnlyActivationType::BF16, WeightOnlyQuantType::Int8b>(
                        m, n, k, gs, warmup, iter)))
                        << "BF16 activation, Int8b weights";
                    EXPECT_TRUE((benchmark<WeightOnlyActivationType::BF16, WeightOnlyQuantType::Int4b>(
                        m, n, k, gs, warmup, iter)))
                        << "BF16 activation, Int4b weights";
#endif
                }
            }
        }
    }
}