TensorRT-LLMs/cpp/tensorrt_llm/thop/cublasScaledMM.cpp
Zhenhua Wang 62298bc473
perf: customize cublastLt algo for Llamba 3.3 70B TP4 (#6315)
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
2025-07-24 23:01:15 +08:00

358 lines
14 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (out) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "userbuffersTensor.h"
#include <array>
#include <cublasLt.h>
#include <torch/extension.h>
#include <unordered_map>
using torch::Tensor;
namespace torch_ext
{
namespace
{
using tensorrt_llm::common::check;
using tensorrt_llm::common::CublasMMWrapper;
struct hash_tuple
{
size_t operator()(std::tuple<int, int, int> const& x) const
{
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
}
};
// got from cublasTest matmultFind
// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 7>, hash_tuple>;
// bf16*bf16->fp32->bf16
AlgoListType bf16_algo_list = {
// Deepseek v3/R1 router gemm
// [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 7168, 256}, {10, 35, 1, 0, 0, 3, 2}},
{{512, 7168, 256}, {48, 35, 1, 0, 0, 0, 2}},
{{1024, 7168, 256}, {13, 35, 1, 0, 0, 1, 3}},
};
// fp8*fp8->fp32->fp16
AlgoListType fp8_algo_list = {
// Llama-3.1-70B
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 8192, 8192}, {393, 36, 1, 0, 0, 5, 2}},
// [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 8192, 57344}, {10, 36, 1, 0, 0, 1, 2}},
// Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
{{8, 8192, 14336}, {393, 36, 1, 0, 1, 1, 4}},
};
void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 7> const& attr_list)
{
auto const& [tileID, stagesID, numsK, reduction, swizzle, customOption_, cga_] = attr_list;
uint32_t customOption = customOption_;
uint16_t cga = cga_;
check_cuda_error(
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileID, sizeof(tileID)));
check_cuda_error(
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesID, sizeof(stagesID)));
check_cuda_error(
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numsK, sizeof(numsK)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reduction, sizeof(reduction)));
check_cuda_error(
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)));
check_cuda_error(
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, &cga, sizeof(cga)));
}
bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper> const& cublasWrapper, int32_t m,
int32_t n, int32_t k, cublasComputeType_t compType, cudaDataType_t scaleType, cudaDataType_t aType,
cudaDataType_t bType, cudaDataType_t outType)
{
int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
AlgoListType algo_list;
if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
&& compType == CUBLAS_COMPUTE_32F)
{
algo_list = bf16_algo_list;
}
else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
{
algo_list = fp8_algo_list;
}
else
{
TLLM_LOG_DEBUG(
"No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
return false;
}
int const algoID = 66; // CUBLASLT_MATMUL_ALGO_NVJET
check_cuda_error(cublasLtMatmulAlgoInit(
cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
{
set_algo_attr(algo, algo_iter->second);
}
else
{
TLLM_LOG_DEBUG("No special cublasLt algo found for m=%d, k=%d, n=%d\n", m, k, n);
return false;
}
TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
return true;
}
bool find_special_algo_deprecated(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper> const& cublasWrapper,
int32_t m, int32_t n, int32_t k, cublasComputeType_t compType, cudaDataType_t scaleType, cudaDataType_t aType,
cudaDataType_t bType, cudaDataType_t outType)
{
int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
if (aType != CUDA_R_8F_E4M3 || compType != CUBLAS_COMPUTE_32F)
{
return false;
}
int const algoID = 52;
check_cuda_error(cublasLtMatmulAlgoInit(
cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
int tileID = CUBLASLT_MATMUL_TILE_256x128;
int swizzle = 0;
uint16_t cga = CUBLASLT_CLUSTER_SHAPE_2x1x1;
int const stagesID = CUBLASLT_MATMUL_STAGES_128xAUTO;
int const numsK = -1;
int const reduction = CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
if (mp2 <= 64)
{
tileID = CUBLASLT_MATMUL_TILE_64x64;
swizzle = 1;
if (n > k) // qkv & gate_up
cga = CUBLASLT_CLUSTER_SHAPE_13x1x1;
else // o & down
cga = CUBLASLT_CLUSTER_SHAPE_10x1x1;
}
else if (mp2 <= 256)
{
if (n > k) // qkv & gate_up
tileID = CUBLASLT_MATMUL_TILE_192x128;
else // o & down
tileID = CUBLASLT_MATMUL_TILE_128x128;
swizzle = 1;
cga = CUBLASLT_CLUSTER_SHAPE_1x2x1;
}
else if (mp2 <= 2048)
{
if (n > k) // qkv & gate_up
tileID = CUBLASLT_MATMUL_TILE_160x128;
else // o & down
tileID = CUBLASLT_MATMUL_TILE_256x128;
}
else
{
return false;
}
set_algo_attr(algo, {tileID, stagesID, numsK, reduction, swizzle, 0, cga});
return true;
}
void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,
std::optional<at::Tensor> const& scale_a, std::optional<at::Tensor> const& scale_b, bool fast_acc = false)
{
bool use_scale = false;
if (scale_a.has_value() && scale_b.has_value())
{
use_scale = true;
}
int32_t m = a.sizes()[0];
int32_t n = b.sizes()[1];
int32_t k = a.sizes()[1];
thread_local std::shared_ptr<CublasMMWrapper> cublasWrapper;
if (cublasWrapper == nullptr)
{
auto cublasHandle = getCublasHandle();
auto cublasLtHandle = getCublasLtHandle();
cublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
}
cudaDataType_t aType = convert_torch_dtype(a.scalar_type());
cudaDataType_t bType = convert_torch_dtype(b.scalar_type());
cudaDataType_t outType = convert_torch_dtype(out.scalar_type());
// hardcode compute type for FP8
cublasComputeType_t compType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto* a_ptr = static_cast<void*>(a.data_ptr());
auto* b_ptr = static_cast<void*>(b.data_ptr());
auto* out_ptr = static_cast<void*>(out.data_ptr());
auto* ws_ptr = static_cast<void*>(workspace.data_ptr());
void* a_scale = nullptr;
void* b_scale = nullptr;
if (use_scale)
{
a_scale = static_cast<void*>(scale_a.value().data_ptr());
b_scale = static_cast<void*>(scale_b.value().data_ptr());
}
cublasWrapper->setStream(stream);
cublasWrapper->setWorkspace(ws_ptr);
// set algo according to m/n/k
cublasLtMatmulAlgo_t algo;
#if CUDART_VERSION < 12080
// nvjet is not supported
bool has_algo
= find_special_algo_deprecated(algo, cublasWrapper, m, n, k, compType, scaleType, aType, bType, outType);
#else
bool has_algo = find_special_algo(algo, cublasWrapper, m, n, k, compType, scaleType, aType, bType, outType);
#endif
// swap A and B. A is column major, B is row major.
cublasWrapper->createDescriptors(
CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, /*lda=*/k, /*ldb=*/k, /*ldc=*/n, /*fastAcc=*/fast_acc);
if (use_scale)
cublasWrapper->setScaleDescriptors(a_scale, b_scale);
cublasWrapper->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, /*A=*/b_ptr, /*lda=*/k, /*B=*/a_ptr, /*ldb=*/k, out_ptr,
/*ldc=*/n, 1.0F, 0.0F, algo, has_algo, true);
cublasWrapper->destroyDescriptors();
}
} // namespace
Tensor& cublas_scaled_mm_out(Tensor const& mat_a, Tensor const& mat_b, Tensor const& scale_a, Tensor const& scale_b,
std::optional<at::Tensor> const& bias, Tensor& out)
{
// Check device
CHECK_TH_CUDA(mat_a);
CHECK_TH_CUDA(mat_b);
CHECK_TH_CUDA(scale_a);
CHECK_TH_CUDA(scale_b);
CHECK_TH_CUDA(out);
TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && out.dim() == 2);
TORCH_CHECK(out.sizes()[0] == mat_a.sizes()[0] && mat_a.sizes()[1] == mat_b.sizes()[0]
&& mat_b.sizes()[1] == out.sizes()[1]);
TORCH_CHECK(scale_a.numel() == 1 || scale_a.numel() == mat_a.sizes()[0]);
TORCH_CHECK(scale_b.numel() == 1 || scale_b.numel() == mat_b.sizes()[1]);
// Check for strides and alignment
TORCH_CHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1); // Row-major
TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major
TORCH_CHECK(out.strides()[0] % 16 == 0 && mat_b.strides()[1] % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(scale_a.is_contiguous() && scale_b.is_contiguous());
TORCH_CHECK(!bias.has_value(), "bias is not support yet");
TORCH_CHECK(mat_a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(mat_b.dtype() == torch::kFloat8_e4m3fn);
cublas_gemm_caller(out, mat_a, mat_b, scale_a, scale_b, true);
return out;
}
Tensor cublas_scaled_mm(Tensor const& mat_a, Tensor const& mat_b, Tensor const& scale_a, Tensor const& scale_b,
std::optional<at::Tensor> const& bias, std::optional<c10::ScalarType> out_dtype, bool to_userbuffers = false)
{
TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2);
auto const out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};
Tensor out;
if (to_userbuffers)
{
out = torch_ext::create_userbuffers_tensor(output_size, out_dtype_).first;
}
else
{
out = at::empty(output_size, mat_a.options().dtype(out_dtype_));
}
return cublas_scaled_mm_out(mat_a, mat_b, scale_a, scale_b, bias, out);
}
Tensor& cublas_mm_out(Tensor const& mat_a, Tensor const& mat_b, std::optional<at::Tensor> const& bias, Tensor& out)
{
// Check device
CHECK_TH_CUDA(mat_a);
CHECK_TH_CUDA(mat_b);
CHECK_TH_CUDA(out);
TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && out.dim() == 2);
// TODO: consider remove mat_b.to() and add extra transa & transb flag like trt's matmul
TORCH_CHECK(out.sizes()[0] == mat_a.sizes()[0] && mat_a.sizes()[1] == mat_b.sizes()[0]
&& mat_b.sizes()[1] == out.sizes()[1]);
// Check for strides and alignment
TORCH_CHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1); // Row-major
TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major
TORCH_CHECK(!bias.has_value(), "bias is not support yet");
cublas_gemm_caller(out, mat_a, mat_b, at::nullopt, at::nullopt, false);
return out;
}
Tensor cublas_mm(Tensor const& mat_a, Tensor const& mat_b, std::optional<at::Tensor> const& bias,
std::optional<c10::ScalarType> out_dtype)
{
TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2);
auto const out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};
Tensor out = at::empty(output_size, mat_a.options().dtype(out_dtype_));
return cublas_mm_out(mat_a, mat_b, bias, out);
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"cublas_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
" ScalarType? out_dtype, bool to_userbuffers=False) -> (Tensor out)");
m.def(
"cublas_scaled_mm_out(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
" int userbuffers_id, Tensor! out) -> (Tensor out)");
m.def("cublas_mm(Tensor mat_a, Tensor mat_b, Tensor? bias, ScalarType? out_dtype) -> (Tensor out)");
m.def("cublas_mm_out(Tensor mat_a, Tensor mat_b, Tensor? bias, Tensor! out) -> (Tensor out)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("cublas_scaled_mm", &torch_ext::cublas_scaled_mm);
m.impl("cublas_mm", &torch_ext::cublas_mm);
}