#include <gtest/gtest.h>

#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include "cutlass/numeric_types.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/fp8Gemm.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <typeinfo>
#include <vector>

using namespace tensorrt_llm::kernels::fp8_gemm;

void simple_assert(bool flag)
{
    if (!flag)
    {
        throw std::runtime_error("assert failed");
    }
}

inline void checkCublasStatus(cublasStatus_t status)
{
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        printf("cuBLAS API failed with status %d\n", status);
        throw std::logic_error("cuBLAS API failed");
    }
}

struct CudaBuffer
{
    void* _data;
    int _size;

    CudaBuffer(int size_in_bytes)
        : _size(size_in_bytes)
    {
        cudaMalloc(&_data, _size);
    }

    template <typename T = void>
    T* data()
    {
        return reinterpret_cast<T*>(_data);
    }

    void copy_to(void* dst)
    {
        cudaMemcpy(dst, _data, _size, cudaMemcpyDeviceToHost);
    }

    void copy_from(void* src)
    {
        cudaMemcpy(_data, src, _size, cudaMemcpyHostToDevice);
    }

    ~CudaBuffer()
    {
        cudaFree(_data);
    }
};

template <typename T>
bool compare(void* _pa, void* _pb, int size)
{
    auto pa = reinterpret_cast<T*>(_pa);
    auto pb = reinterpret_cast<T*>(_pb);
    float max_diff = 0.f, tot_diff = 0.f;
    float max_val = 0.f;
    int diff_cnt = 0;
    float threshold = 1e-7;
    for (int n = 0; n < size; ++n)
    {
        float va = static_cast<float>(pa[n]);
        float vb = static_cast<float>(pb[n]);
        max_val = std::max(max_val, vb);
        float diff = std::abs(va - vb);
        if (diff > threshold)
        {
            max_diff = std::max(max_diff, diff);
            tot_diff += diff;
            ++diff_cnt;
        }
    }
    float diff_thres = max_val * 2e-3;
#if defined(ENABLE_BF16)
    if constexpr (std::is_same_v<T, __nv_bfloat16>)
    {
        // bfloat16 has fewer mantissa bits than float16 (10 bits for fp16 but only 7 bits for bf16), so the
        // cumulative error will be larger.
        diff_thres *= 3.f;
    }
    else
#endif
    {
        diff_thres *= 1.5f;
    }
    printf("max diff %f (diff threshold %f), avg diff %f, diff cnt %d/%d\n", max_diff, diff_thres,
        tot_diff / std::max(1, diff_cnt), diff_cnt, size);
    return max_diff <= diff_thres;
}
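// Worked example of the tolerance above (numbers are illustrative, not measured): if the reference
// output's largest value is max_val = 8.0, the base threshold is 8.0 * 2e-3 = 0.016, which becomes
// 0.024 for fp32/fp16 outputs (x1.5) and 0.048 for bf16 outputs (x3, reflecting bf16's 7 mantissa
// bits versus fp16's 10).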
template <typename T1, typename T2>
void random_fill(std::vector<T1>& vec, T2 minv, T2 maxv)
{
    std::mt19937 gen(rand());
    std::uniform_real_distribution<float> dis(static_cast<float>(minv), static_cast<float>(maxv));
    for (auto& v : vec)
    {
        v = static_cast<T1>(dis(gen));
    }
}

template <typename T1, typename T2>
void constant_fill(std::vector<T1>& vec, T2 value)
{
    for (auto& v : vec)
    {
        v = static_cast<T1>(value);
    }
}

template <typename T>
void linear_fill(std::vector<T>& vec, int length)
{
    for (size_t i = 0; i < vec.size(); ++i)
    {
        vec[i] = static_cast<T>((i % length) / 100.f);
    }
}

template <typename T>
void print_mat(std::vector<T> const& data, int row, int col, char const* name)
{
    assert(data.size() == static_cast<size_t>(row) * col);
    printf("---------------%s\n", name);
    for (size_t n = 0; n < data.size(); ++n)
    {
        float value = static_cast<float>(data[n]);
        printf("%f, ", value);
        if (n % col == col - 1)
            printf("\n");
    }
    printf("\n");
}

template <typename InputType, typename OutputType>
void run_cpu(void* weight, void* activation, float scale, Params const& params, void* output)
{
    for (int idx_m = 0; idx_m < params.m; ++idx_m)
    {
        for (int idx_n = 0; idx_n < params.n; ++idx_n)
        {
            float acc = 0.f;
            for (int idx_k = 0; idx_k < params.k; ++idx_k)
            {
                InputType a = reinterpret_cast<InputType*>(activation)[params.k * idx_m + idx_k];
                InputType w = reinterpret_cast<InputType*>(weight)[params.k * idx_n + idx_k];
                acc += static_cast<float>(w) * static_cast<float>(a);
            }
            reinterpret_cast<OutputType*>(output)[idx_m * params.n + idx_n] = static_cast<OutputType>(acc * scale);
        }
    }
}
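// Memory layout used throughout this test: the activation is an m x k row-major matrix and the
// weight is an n x k row-major matrix (k is the contiguous dimension for both), so run_cpu above
// computes output[idx_m][idx_n] = scale * sum_k act[idx_m][idx_k] * weight[idx_n][idx_k], i.e.
// act * W^T, stored as an m x n row-major matrix. For example, with k = 4096 the activation element
// (idx_m = 1, idx_k = 2) lives at offset 1 * 4096 + 2 = 4098. The same device buffers are handed
// unchanged to the CUDA kernel and to the cuBLASLt path below, which maps this onto cuBLASLt's
// required TN layout by binding the weight as A and the activation as B; the resulting n x m
// column-major D occupies memory identically to the m x n row-major output expected here.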
float run_cuda_kernel(Params& params, int warmup, int iter)
{
    cudaStream_t s;
    cudaStreamCreate(&s);
    cudaEvent_t begin, end;
    cudaEventCreate(&begin);
    cudaEventCreate(&end);
    for (int i = 0; i < warmup; ++i)
    {
        tensorrt_llm::kernels::fp8_gemm::fp8GemmDispatcher(params, s);
    }
    cudaEventRecord(begin, s);
    for (int i = 0; i < iter; ++i)
    {
        tensorrt_llm::kernels::fp8_gemm::fp8GemmDispatcher(params, s);
    }
    cudaEventRecord(end, s);
    cudaEventSynchronize(end);
    float time;
    cudaEventElapsedTime(&time, begin, end);
    cudaEventDestroy(begin);
    cudaEventDestroy(end);
    cudaStreamDestroy(s);
    return time / iter;
}

template <typename OutputType>
float run_cublas_kernel(Params& params, int warmup, int iter)
{
    constexpr cudaDataType_t kOutputDatatype = std::is_same<OutputType, __nv_bfloat16>::value
        ? CUDA_R_16BF
        : (std::is_same<OutputType, float>::value ? CUDA_R_32F : CUDA_R_16F);
    // use weight as A, use activation as B so that D is transposed
    void const* A = params.weight;
    void const* B = params.act;
    void* D = params.output;
    int m = params.m, n = params.n, k = params.k;
    float h_alpha = params.alpha;

    void* workspace = nullptr;
    size_t workspaceSize = 32 * 1024 * 1024; // 32 MB for Hopper
    cudaMalloc(&workspace, workspaceSize);
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    cudaEvent_t begin, end;
    cudaEventCreate(&begin);
    cudaEventCreate(&end);
    cublasLtHandle_t ltHandle;
    checkCublasStatus(cublasLtCreate(&ltHandle));

    cublasLtMatmulDesc_t operationDesc = NULL;
    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
    cublasLtMatmulPreference_t preference = NULL;
    // only TN is supported for FP8
    cublasOperation_t transa = CUBLAS_OP_T;
    cublasOperation_t transb = CUBLAS_OP_N;
    float h_beta = 0.0; // can be non-zero starting from CUDA 12.0
    int returnedResults = 0;
    cublasLtMatmulHeuristicResult_t heuristicResult = {};

    // Create the operation descriptor; see cublasLtMatmulDescAttributes_t for details about defaults.
    // Here we only need to set the transforms for A and B.
    checkCublasStatus(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
    checkCublasStatus(
        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
    checkCublasStatus(
        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

    // Create matrix descriptors; the defaults are fine here, so no extra attributes are set.
    // The table of supported type combinations can be found in the documentation:
    // https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8F_E4M3, k, n, k));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8F_E4M3, k, m, k));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, kOutputDatatype, n, m, n));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Ddesc, kOutputDatatype, n, m, n));

    // Create the preference handle; extra attributes could be used here to disable tensor ops or to
    // make sure the selected algo works with badly aligned A, B, C. For simplicity we assume A, B, C
    // are always well aligned (e.g. they come directly from cudaMalloc).
    checkCublasStatus(cublasLtMatmulPreferenceCreate(&preference));
    checkCublasStatus(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));

    // We just need the best available heuristic to try and run the matmul. There is no guarantee this
    // will work; e.g. if A is badly aligned, you can request more (e.g. 32) algos and try to run them
    // one by one until something works.
    checkCublasStatus(cublasLtMatmulAlgoGetHeuristic(
        ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults));
    if (returnedResults == 0)
    {
        checkCublasStatus(CUBLAS_STATUS_NOT_SUPPORTED);
    }

    for (int i = 0; i < warmup; ++i)
    {
        checkCublasStatus(cublasLtMatmul(ltHandle, operationDesc, &h_alpha, A, Adesc, B, Bdesc, &h_beta, nullptr,
            Cdesc, D, Ddesc, &heuristicResult.algo, workspace, workspaceSize, stream));
    }
    cudaEventRecord(begin, stream);
    for (int i = 0; i < iter; ++i)
    {
        checkCublasStatus(cublasLtMatmul(ltHandle, operationDesc, &h_alpha, A, Adesc, B, Bdesc, &h_beta, nullptr,
            Cdesc, D, Ddesc, &heuristicResult.algo, workspace, workspaceSize, stream));
    }
    cudaEventRecord(end, stream);
    cudaEventSynchronize(end);
    if (workspace)
        cudaFree(workspace);
    float time;
    cudaEventElapsedTime(&time, begin, end);
    cudaEventDestroy(begin);
    cudaEventDestroy(end);
    cudaStreamDestroy(stream);

    // Descriptors are no longer needed, as all GPU work has already been enqueued.
    if (preference)
        checkCublasStatus(cublasLtMatmulPreferenceDestroy(preference));
    if (Ddesc)
        checkCublasStatus(cublasLtMatrixLayoutDestroy(Ddesc));
    if (Cdesc)
        checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
    if (Bdesc)
        checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
    if (Adesc)
        checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
    if (operationDesc)
        checkCublasStatus(cublasLtMatmulDescDestroy(operationDesc));
    checkCublasStatus(cublasLtDestroy(ltHandle));
    return time / iter;
}
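// The timing helpers above return the average time per GEMM in milliseconds. To convert to
// throughput (illustrative arithmetic, not a measured number): TFLOP/s = 2 * m * n * k /
// (time_ms * 1e-3) / 1e12, so e.g. m = 4, n = k = 4096 at 0.01 ms per call corresponds to roughly
// 13.4 TFLOP/s.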
template <typename InputType, typename OutputType>
bool benchmark_and_verify(int m, int n, int k, tensorrt_llm::common::QuantMode const& quant_mode, int warmup, int iter,
    bool debug = false, bool run_cublas = false)
{
    constexpr nvinfer1::DataType kInputDatatype = std::is_same<InputType, float>::value
        ? nvinfer1::DataType::kFLOAT
        : std::is_same<InputType, half>::value ? nvinfer1::DataType::kHALF
        : std::is_same<InputType, __nv_bfloat16>::value ? nvinfer1::DataType::kBF16
                                                        : nvinfer1::DataType::kFP8;
    constexpr nvinfer1::DataType kOutputDatatype = std::is_same<OutputType, float>::value
        ? nvinfer1::DataType::kFLOAT
        : std::is_same<OutputType, half>::value ? nvinfer1::DataType::kHALF
        : std::is_same<OutputType, __nv_bfloat16>::value ? nvinfer1::DataType::kBF16
                                                         : nvinfer1::DataType::kFP8;
    std::srand(20240123);
    simple_assert(m <= 4);
    printf("mnk (%d, %d, %d), output %s\n", m, n, k, typeid(OutputType).name());

    CudaBuffer d_act(m * k * sizeof(InputType));
    CudaBuffer d_weight(k * n * sizeof(InputType));
    CudaBuffer d_out(m * n * sizeof(OutputType));
    std::vector<InputType> h_act(m * k);
    std::vector<InputType> h_weight(k * n);
    std::vector<float> h_alpha(1);
    std::vector<OutputType> h_out_cuda(m * n), h_out_cublas(m * n), h_out_gt(m * n);

    random_fill(h_act, -1.f, 1.f);
    random_fill(h_weight, -1.f, 1.f);
    random_fill(h_alpha, -1.f, 1.f);

    if (debug)
    {
        print_mat(h_act, m, k, "h_act");
        print_mat(h_weight, k, n, "h_weight");
        print_mat(h_alpha, 1, 1, "h_alpha");
    }

    d_act.copy_from(h_act.data());
    d_weight.copy_from(h_weight.data());

    Params params{
        d_act.data(), d_weight.data(), h_alpha[0], d_out.data(), m, n, k, quant_mode, kInputDatatype, kOutputDatatype};
    run_cpu<InputType, OutputType>(h_weight.data(), h_act.data(), h_alpha[0], params, h_out_gt.data());

    float time1, time2;
    time1 = run_cuda_kernel(params, warmup, iter);
    d_out.copy_to(h_out_cuda.data());
    bool pass_cuda_kernel = compare<OutputType>(h_out_cuda.data(), h_out_gt.data(), m * n);
    if (debug)
    {
        print_mat(h_out_gt, m, n, "h_out_cpu");
        print_mat(h_out_cuda, m, n, "h_out_cuda");
    }

    if (run_cublas)
    {
        time2 = run_cublas_kernel<OutputType>(params, warmup, iter);
        d_out.copy_to(h_out_cublas.data());
        bool pass_cublas = compare<OutputType>(h_out_cublas.data(), h_out_gt.data(), m * n);
        if (debug)
        {
            print_mat(h_out_cublas, m, n, "h_out_cublas");
        }
        printf("cuda kernel cost time %.6f, cublas kernel cost time %.6f, cuda speedup %.3f\n", time1, time2,
            time2 / time1);
        return pass_cuda_kernel && pass_cublas;
    }
    printf("cuda kernel cost time %.6f\n", time1);
    return pass_cuda_kernel;
}
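// Example of a one-off debug run outside the test loops below (hypothetical call, not part of the
// original suite): enabling the last two flags prints the matrices and also checks against cuBLASLt.
//
//   benchmark_and_verify<__nv_fp8_e4m3, half>(
//       2, 4096, 4096, quant_mode, /*warmup=*/10, /*iter=*/30, /*debug=*/true, /*run_cublas=*/true);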
#ifdef ENABLE_FP8
TEST(Kernel, Fp8Gemv)
{
    int const arch = tensorrt_llm::common::getSMVersion();
    bool pass;
    int warmup = 10, iter = 30;
    std::vector<int> ms{1, 2, 3, 4};
    std::vector<int> ns{2048, 4096};
    std::vector<int> ks{2048, 4096};
    tensorrt_llm::common::QuantMode quant_mode = tensorrt_llm::common::QuantMode::fromQuantAlgo("FP8");
    for (auto m : ms)
    {
        for (auto n : ns)
        {
            for (auto k : ks)
            {
                pass = benchmark_and_verify<__nv_fp8_e4m3, float>(m, n, k, quant_mode, warmup, iter);
                EXPECT_TRUE(pass);
                pass = benchmark_and_verify<__nv_fp8_e4m3, half>(m, n, k, quant_mode, warmup, iter);
                EXPECT_TRUE(pass);
#if defined(ENABLE_BF16)
                pass = benchmark_and_verify<__nv_fp8_e4m3, __nv_bfloat16>(m, n, k, quant_mode, warmup, iter);
                EXPECT_TRUE(pass);
#endif
            }
        }
    }
}
#endif

TEST(Kernel, Fp16Gemv)
{
    int const arch = tensorrt_llm::common::getSMVersion();
    bool pass;
    int warmup = 10, iter = 30;
    std::vector<int> ms{1, 2, 3, 4};
    std::vector<int> ns{2047, 2048, 4096};
    std::vector<int> ks{2048, 4096};
    tensorrt_llm::common::QuantMode quant_mode = tensorrt_llm::common::QuantMode::fromQuantAlgo("FP8");
    for (auto m : ms)
    {
        for (auto n : ns)
        {
            for (auto k : ks)
            {
                // assumed template arguments: half input with float and half outputs, mirroring the
                // fp8 variants above
                pass = benchmark_and_verify<half, float>(m, n, k, quant_mode, warmup, iter);
                EXPECT_TRUE(pass);
                pass = benchmark_and_verify<half, half>(m, n, k, quant_mode, warmup, iter);
                EXPECT_TRUE(pass);
#if defined(ENABLE_BF16)
                pass = benchmark_and_verify<__nv_bfloat16, __nv_bfloat16>(m, n, k, quant_mode, warmup, iter);
                EXPECT_TRUE(pass);
#endif
            }
        }
    }
}