mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
221 lines
7.1 KiB
C++
221 lines
7.1 KiB
C++
/*
|
|
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "tensorrt_llm/common/cublasAlgoMap.h"
|
|
|
|
namespace tensorrt_llm
|
|
{
|
|
namespace common
|
|
{
|
|
|
|
cublasAlgoMap::cublasAlgoMap(const std::string filename, const std::string sp_config_filename)
|
|
: config_filename_(filename)
|
|
, sp_config_filename_(sp_config_filename)
|
|
{
|
|
loadGemmConfig();
|
|
loadSpGemmConfig();
|
|
}
|
|
|
|
cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap& algo_map)
|
|
: config_filename_(algo_map.config_filename_)
|
|
, sp_config_filename_(algo_map.sp_config_filename_)
|
|
, algo_map_(algo_map.algo_map_)
|
|
, sp_algo_map_(algo_map.sp_algo_map_)
|
|
{
|
|
}
|
|
|
|
cublasAlgoMap::~cublasAlgoMap()
|
|
{
|
|
algo_map_.clear();
|
|
}
|
|
|
|
void cublasAlgoMap::loadGemmConfig()
|
|
{
|
|
FILE* fd;
|
|
fd = fopen(config_filename_.c_str(), "r");
|
|
if (fd == NULL)
|
|
{
|
|
return;
|
|
}
|
|
|
|
int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val;
|
|
int batch_size, seq_len, head_num, size_per_head, dataType;
|
|
int swizzle, reductionScheme, workspaceSize, stages;
|
|
int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode;
|
|
float exec_time;
|
|
char tmp[1024];
|
|
if (!fgets(tmp, 1024, fd))
|
|
{
|
|
printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
|
|
exit(-1);
|
|
}
|
|
while (fscanf(fd,
|
|
"%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d "
|
|
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
|
|
"%d %d "
|
|
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
|
|
"%d %d %d "
|
|
#endif
|
|
"%f\n",
|
|
&batch_size, &seq_len, &head_num, &size_per_head, &dataType, &batchCount2, &n2, &m2, &k2, &algoId,
|
|
&customOption, &tile, &splitK_val, &swizzle, &reductionScheme, &workspaceSize, &stages,
|
|
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
|
|
&inner_shapeId, &cluster_shapeId,
|
|
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
|
|
&mma_shapeId, &cga_shapeId, &sche_mode,
|
|
#endif
|
|
&exec_time)
|
|
!= EOF)
|
|
{
|
|
if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && dataType != BFLOAT16_DATATYPE
|
|
&& dataType != INT8_DATATYPE && dataType != FP8_DATATYPE)
|
|
{
|
|
printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType);
|
|
continue;
|
|
}
|
|
char mark[256];
|
|
sprintf(mark, "%d_%d_%d_%d_%d", batchCount2, m2, n2, k2, dataType);
|
|
std::string markStr(mark);
|
|
// workspaceSize should be zero
|
|
if (algo_map_.find(markStr) == algo_map_.end())
|
|
{
|
|
algo_map_[markStr].algoId = algoId;
|
|
algo_map_[markStr].customOption = customOption;
|
|
algo_map_[markStr].tile = tile;
|
|
algo_map_[markStr].splitK_val = splitK_val;
|
|
algo_map_[markStr].swizzle = swizzle;
|
|
algo_map_[markStr].reductionScheme = reductionScheme;
|
|
algo_map_[markStr].workspaceSize = workspaceSize;
|
|
algo_map_[markStr].stages = stages;
|
|
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
|
|
algo_map_[markStr].inner_shapeId = (uint16_t) inner_shapeId;
|
|
algo_map_[markStr].cluster_shapeId = (uint16_t) cluster_shapeId;
|
|
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
|
|
algo_map_[markStr].mma_shapeId = (uint16_t) mma_shapeId;
|
|
algo_map_[markStr].cga_shapeId = (uint16_t) cga_shapeId;
|
|
algo_map_[markStr].sche_mode = (uint16_t) sche_mode;
|
|
#endif
|
|
algo_map_[markStr].exec_time = exec_time;
|
|
}
|
|
}
|
|
fclose(fd);
|
|
}
|
|
|
|
bool cublasAlgoMap::isExist(
|
|
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
|
|
{
|
|
char mark[256];
|
|
sprintf(mark, "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
|
|
return algo_map_.find(mark) != algo_map_.end();
|
|
}
|
|
|
|
cublasLtMatmulAlgo_info cublasAlgoMap::getAlgo(
|
|
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
|
|
{
|
|
char mark[256];
|
|
sprintf(mark, "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
|
|
if (algo_map_.find(mark) != algo_map_.end())
|
|
{
|
|
return algo_map_[mark];
|
|
}
|
|
else
|
|
{
|
|
cublasLtMatmulAlgo_info tmp_algo;
|
|
tmp_algo.algoId
|
|
= static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
|
tmp_algo.customOption = -1;
|
|
tmp_algo.tile = -1;
|
|
tmp_algo.splitK_val = -1;
|
|
tmp_algo.swizzle = -1;
|
|
tmp_algo.reductionScheme = -1;
|
|
tmp_algo.workspaceSize = -1;
|
|
tmp_algo.stages = -1;
|
|
tmp_algo.exec_time = -1.0f;
|
|
return tmp_algo;
|
|
}
|
|
}
|
|
|
|
void cublasAlgoMap::loadSpGemmConfig()
|
|
{
|
|
if (sp_config_filename_.empty())
|
|
{
|
|
return;
|
|
}
|
|
FILE* fd = fopen(sp_config_filename_.c_str(), "r");
|
|
if (fd == NULL)
|
|
{
|
|
return;
|
|
}
|
|
sp_algo_map_.clear();
|
|
int batch_size, seq_len, head_num, size_per_head, data_type;
|
|
int batchCount, m, n, k, algoId;
|
|
float exec_time;
|
|
char tmp[1024];
|
|
if (!fgets(tmp, 1024, fd))
|
|
{
|
|
printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
|
|
exit(-1);
|
|
}
|
|
while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, &seq_len, &head_num, &size_per_head,
|
|
&data_type, &batchCount, &m, &n, &k, &algoId, &exec_time)
|
|
!= EOF)
|
|
{
|
|
char mark[256];
|
|
sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k);
|
|
std::string markStr(mark);
|
|
sp_algo_map_[markStr] = algoId;
|
|
}
|
|
fclose(fd);
|
|
}
|
|
|
|
int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, const int k)
|
|
{
|
|
char mark[256];
|
|
sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k);
|
|
if (sp_algo_map_.find(mark) != sp_algo_map_.end())
|
|
{
|
|
return sp_algo_map_[mark];
|
|
}
|
|
else
|
|
{
|
|
// for remove padding, select algo 1 for simplicity
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, const int k)
|
|
{
|
|
// not available to use cusparselt.
|
|
if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0)
|
|
{
|
|
return false;
|
|
}
|
|
char mark[256];
|
|
sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k);
|
|
if (sp_algo_map_.find(mark) != sp_algo_map_.end())
|
|
{
|
|
return sp_algo_map_[mark] != -1;
|
|
}
|
|
else
|
|
{
|
|
// no gemm test case, choose sparse according to sparse flag
|
|
return true;
|
|
}
|
|
}
|
|
|
|
} // namespace common
|
|
} // namespace tensorrt_llm
|