// TensorRT-LLM/cpp/tensorrt_llm/common/cublasMMWrapper.cpp
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
namespace tensorrt_llm
{
namespace common
{
cublasMMWrapper::cublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, cublasAlgoMap* cublas_algo_map,
std::mutex* mu, void* workspace)
: cublas_handle_(cublasHandle)
, cublaslt_handle_(cublasltHandle)
, stream_(stream)
, cublas_algo_map_(cublas_algo_map)
, mu_(mu)
, cublas_workspace_(workspace)
{
}
cublasMMWrapper::~cublasMMWrapper()
{
mu_ = nullptr;
}
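// Note: copying shares the cuBLAS/cublasLt handles, stream, algo map, and mutex
// with the source wrapper; the workspace pointer is not copied, so call
// setWorkspace() on the copy before using the cublasLt paths.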
cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper)
: cublas_handle_(wrapper.cublas_handle_)
, cublaslt_handle_(wrapper.cublaslt_handle_)
, stream_(wrapper.stream_)
, cublas_algo_map_(wrapper.cublas_algo_map_)
, mu_(wrapper.mu_)
{
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* alpha, const void* A, cudaDataType_t Atype, int lda, const void* B, cudaDataType_t Btype, int ldb,
const void* beta, void* C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo)
{
mu_->lock();
check_cuda_error(cublasGemmEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta,
C, Ctype, ldc, computeType, algo));
sync_check_cuda_error();
mu_->unlock();
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
}
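// Dispatch rule: fp16 inputs go to cublasLt, everything else to plain cuBLAS.
// If the offline-profiled algo map has an entry for (batch_count, m, n, k, dtype),
// its stages field overrides the backend choice (-1 means plain cuBLAS) and its
// recorded algo configuration is replayed instead of the library heuristic.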
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta)
{
half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta);
mu_->lock();
// TODO: default cublas libs
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
// fp32 uses plain cuBLAS by default; fp16 uses cublasLt by default.
bool using_cublasLt = (Atype_ == CUDA_R_16F);
int batch_count = 1;
const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(Atype_));
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
if (findAlgo)
{
if (info.stages != -1)
{
using_cublasLt = true;
}
else
{
using_cublasLt = false;
}
}
if (using_cublasLt)
{
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (is_fp16_computeType)
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
#if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
cublasLtMatmulDescCreate(&operationDesc, computeType);
#endif
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulAlgo_t algo;
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (findAlgo)
{
if (info.workspaceSize > workspaceSize)
{
findAlgo = 0;
}
else
{
cublasLtMatmulAlgoInit(
*cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif
}
}
cublasLtMatmul(*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc,
(findAlgo == 1 ? (&algo) : NULL), workSpace, workspaceSize, stream_);
cublasLtMatmulDescDestroy(operationDesc);
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc);
sync_check_cuda_error();
}
else
{
int cublasAlgo = info.algoId;
check_cuda_error(cublasGemmEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_, ldb,
beta, C, Ctype_, ldc, computeType_, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error();
}
mu_->unlock();
}
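// Usage sketch (illustrative only; wrapper, d_A, d_B, and d_C are hypothetical,
// and the usual cuBLAS column-major conventions apply):
//
// wrapper.setFP16GemmConfig();
// // C (m x n) = A (m x k) * B (k x n), fp16 in/out with fp32 accumulation
// wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, d_A, m, d_B, k, d_C, m);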
void cublasMMWrapper::setWorkspace(void* workspace)
{
cublas_workspace_ = workspace;
}
void cublasMMWrapper::setFP32GemmConfig()
{
Atype_ = CUDA_R_32F;
Btype_ = CUDA_R_32F;
Ctype_ = CUDA_R_32F;
computeType_ = CUDA_R_32F;
}
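// The fp16 and bf16 configs below accumulate in fp32 (computeType_ = CUDA_R_32F)
// for accuracy; true fp16 accumulation must be requested via setGemmConfig().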
void cublasMMWrapper::setFP16GemmConfig()
{
Atype_ = CUDA_R_16F;
Btype_ = CUDA_R_16F;
Ctype_ = CUDA_R_16F;
computeType_ = CUDA_R_32F;
}
#ifdef ENABLE_BF16
void cublasMMWrapper::setBF16GemmConfig()
{
Atype_ = CUDA_R_16BF;
Btype_ = CUDA_R_16BF;
Ctype_ = CUDA_R_16BF;
computeType_ = CUDA_R_32F;
}
#endif
void cublasMMWrapper::setGemmConfig(
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
Atype_ = aType;
Btype_ = bType;
Ctype_ = cType;
computeType_ = computeType;
}
CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
if (data_type == CUDA_R_16F)
{
return HALF_DATATYPE;
}
else if (data_type == CUDA_R_32F)
{
return FLOAT_DATATYPE;
}
else if (data_type == CUDA_R_8I)
{
return INT8_DATATYPE;
}
#ifdef ENABLE_BF16
else if (data_type == CUDA_R_16BF)
{
return BFLOAT16_DATATYPE;
}
#endif
return FLOAT_DATATYPE;
}
#if (CUDART_VERSION >= 11000)
// input, weight, and output are row-major
// requires cuBLAS 11.x or newer (cublasComputeType_t descriptors), hence the
// CUDART_VERSION >= 11000 guard above
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, const void* bias, void* C, const int ldc)
{
cudaDataType_t Atype, Btype, Ctype;
cublasComputeType_t computeType;
cudaDataType_t scaleType;
float alpha_float = 1.0f;
float beta_float = 0.0f;
half alpha_half = half(1.0f);
half beta_half = half(0.0f);
void *alpha, *beta;
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
if (Atype_ == CUDA_R_32F)
{
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_32F;
Btype = CUDA_R_32F;
Ctype = CUDA_R_32F;
scaleType = CUDA_R_32F;
alpha = &alpha_float;
beta = &beta_float;
}
else if (Atype_ == CUDA_R_16BF)
{
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_16BF;
Btype = CUDA_R_16BF;
Ctype = CUDA_R_16BF;
scaleType = CUDA_R_32F;
alpha = &alpha_float;
beta = &beta_float;
}
else
{
computeType = CUBLAS_COMPUTE_16F;
Atype = CUDA_R_16F;
Btype = CUDA_R_16F;
Ctype = CUDA_R_16F;
scaleType = CUDA_R_16F;
alpha = &alpha_half;
beta = &beta_half;
}
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
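// CUBLASLT_EPILOGUE_BIAS fuses a broadcast bias add (one value per row of the
// output) into the matmul; the bias pointer is attached to the descriptor below.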
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
check_cuda_error(cublasLtMatmul(
*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc);
cublasLtMatmulDescDestroy(operationDesc);
}
#endif
void cublasMMWrapper::setStream(cudaStream_t stream)
{
stream_ = stream;
}
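// Strided batched GEMM: problem i reads A + i * strideA and B + i * strideB and
// writes C + i * strideC, so all matrices must be laid out at fixed strides.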
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batch_count, const float f_alpha,
const float f_beta)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
mu_->lock();
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
const void* alpha
= is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmStridedBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda,
strideA, B, Btype_, ldb, strideB, beta, C, Ctype_, ldc, strideC, batch_count, computeType_,
static_cast<cublasGemmAlgo_t>(info.algoId)));
mu_->unlock();
}
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batch_count, cudaDataType_t computeType)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
mu_->lock();
int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
const void* alpha
= is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmStridedBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, AType, lda, strideA,
B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batch_count, computeType,
static_cast<cublasGemmAlgo_t>(info.algoId)));
mu_->unlock();
}
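// Pointer-array batched GEMM: A, B, and C are device arrays of batch_count
// per-matrix pointers, so the matrices need not be equally spaced in memory.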
void cublasMMWrapper::batchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const void* const* A, const int lda, const void* const* B, const int ldb, void* const* C,
const int ldc, const int batch_count)
{
float f_alpha = 1.0f;
float f_beta = 0.0f;
half h_alpha = (half) 1.0f;
half h_beta = (half) 0.0f;
mu_->lock();
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_,
ldb, beta, C, Ctype_, ldc, batch_count, computeType_, static_cast<cublasGemmAlgo_t>(info.algoId)));
mu_->unlock();
}
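// Profiling-based heuristic: fusing into a single batched GEMM is considered
// worthwhile only if its measured time beats 3x the single-problem time.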
bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, const int k, const int n)
{
CublasDataType data_type = getCublasDataType(Atype_);
if (!cublas_algo_map_->isExist(batch_count, m, k, n, data_type)
    || !cublas_algo_map_->isExist(1, m, k, n, data_type))
{
return false;
}
else
{
return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type).exec_time
< 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time;
}
}
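// Brute-force autotuner: queries up to 200 cublasLt heuristic candidates, times
// each over 11 launches, and returns the algorithm with the best median latency
// (results[5] of the sorted timings). The found flag is false if nothing ran.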
std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B,
cublasLtMatrixLayout_t Bdesc, const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D,
cublasLtMatrixLayout_t Ddesc, cudaStream_t stream)
{
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
return {false, cublasLtMatmulAlgo_t{}};
#else
size_t returnSize;
int32_t pointer_mode;
cublasLtMatmulDescGetAttribute(
computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize);
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
uint32_t pointer_mode_mask = 0;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif
int return_count = 0;
// A failed heuristic query simply yields zero candidates, so the status is not checked.
cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
heuristics.size(), heuristics.data(), &return_count);
heuristics.resize(return_count);
// The preference handle is only needed for the heuristic query; release it here.
check_cuda_error(cublasLtMatmulPreferenceDestroy(preference));
std::map<int, std::vector<float>> algo_results;
for (const auto& heuristic : heuristics)
{
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
cudaEvent_t start_event, stop_event;
cudaEventCreate(&start_event);
cudaEventCreate(&stop_event);
for (int i = 0; i < 11; i++)
{
float duration_ms;
cudaEventRecord(start_event, stream);
check_cuda_error(cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D,
Ddesc, &algo, cublas_workspace_, CUBLAS_WORKSPACE_SIZE, stream));
cudaEventRecord(stop_event, stream);
cudaEventSynchronize(stop_event);
cudaEventElapsedTime(&duration_ms, start_event, stop_event);
algo_results[algo_id].push_back(duration_ms);
}
std::sort(algo_results[algo_id].begin(), algo_results[algo_id].end());
// Release the timing events created for this heuristic.
cudaEventDestroy(start_event);
cudaEventDestroy(stop_event);
}
cublasLtMatmulHeuristicResult_t result;
float best_time = INFINITY;
for (const auto& heuristic : heuristics)
{
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
const auto& results = algo_results[algo_id];
if (results.size() > 0 && results[5] < best_time)
{
best_time = results[5];
result = heuristic;
}
}
return {best_time != INFINITY, result.algo};
#endif
}
cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc)
{
size_t returnSize;
MatrixLayout m_layout;
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &std::get<1>(m_layout), sizeof(std::get<1>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize);
return m_layout;
}
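// Drop-in replacement for cublasLtMatmul: when no algo is supplied, it autotunes
// once per unique {computeDesc, A/B/C/D layout} key via findBestAlgo() and caches
// the winner in algo_cache for subsequent calls with the same shapes.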
cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream)
{
cache_idx_t cache_idx{computeDesc,
{createMatrixLayout(Adesc), createMatrixLayout(Bdesc), createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}};
cublasLtMatmulAlgo_t algo_value;
bool found_algo = false;
if (algo == nullptr)
{
if (algo_cache.find(cache_idx) == algo_cache.end())
{
auto result
= findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, stream);
if (result.first)
{
algo_cache[cache_idx] = result.second;
algo_value = result.second;
found_algo = true;
}
}
else
{
algo_value = algo_cache[cache_idx];
found_algo = true;
}
}
return cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc,
found_algo ? &algo_value : algo, workspace, workspaceSizeInBytes, stream);
}
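// Int8 GEMM via cublasLt. A is consumed transposed (CUBLAS_OP_T): a row-major
// m x k matrix is bit-identical to a column-major k x m one, so transposing the
// k x m descriptor yields the desired m x k operand.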
void cublasMMWrapper::_Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, void* C, const int ldc, const void* alpha, const int mode, const bool per_column_scaling)
{
/* mode:
* - 0: int8 * int8 -> int32 -> int8
* - 1: int8 * int8 -> int32 -> int32
*/
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
#else
mu_->lock();
const auto op_a = CUBLAS_OP_T;
const auto op_b = CUBLAS_OP_N;
const auto dataType = CUDA_R_8I;
const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
const auto computeType = CUBLAS_COMPUTE_32I;
const auto scaleType = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
const int batch_count = 1;
const void* beta;
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(dataType));
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(dataType));
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda));
check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc));
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
auto pointer_mode = CUBLASLT_POINTER_MODE_HOST;
if (mode == 0)
{
pointer_mode
= per_column_scaling ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST : CUBLASLT_POINTER_MODE_DEVICE;
}
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &op_b, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
const int32_t int_one = 1;
const int32_t int_zero = 0;
const float float_zero = 0;
if (mode == 0)
{
beta = per_column_scaling ? &float_zero : NULL;
}
else
{
alpha = &int_one;
beta = &int_zero;
}
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
sync_check_cuda_error();
auto ret = cublasLtMatmulWrapper(*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C,
Cdesc, NULL, workSpace, workspaceSize, stream_);
check_cuda_error(ret);
sync_check_cuda_error();
cublasLtMatmulDescDestroy(operationDesc);
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc);
sync_check_cuda_error();
mu_->unlock();
#endif
}
void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, int8_t* C, const int ldc, const float* alpha, const bool per_column_scaling)
{
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, per_column_scaling);
}
void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, int32_t* C, const int ldc)
{
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float*) nullptr, 1, false);
}
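// Usage sketch for the int8 path (illustrative only; the buffers and scales are
// hypothetical, with lda = ldb = k and ldc = m as required by the descriptors):
//
// // int8 in, int8 out, per-column requantization scales resident on the device
// wrapper.Int8Gemm(m, n, k, d_A8, k, d_B8, k, d_C8, m, d_scales, true);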
} // namespace common
} // namespace tensorrt_llm