/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "userbuffersTensor.h"

#include <array>
#include <optional>
#include <tuple>
#include <unordered_map>

using torch::Tensor;

namespace torch_ext
{

namespace
{

using tensorrt_llm::common::check;
using tensorrt_llm::common::CublasMMWrapper;

struct hash_tuple
{
    size_t operator()(std::tuple<int, int, int> const& x) const
    {
        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
    }
};

// Obtained from cublasTest matmultFind
// {mp2, k, n}: {m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
using AlgoListType = std::unordered_map<std::tuple<int, int, int>, std::array<int, 7>, hash_tuple>;

// bf16*bf16->fp32->bf16
AlgoListType bf16_algo_list = {
    // Deepseek v3/R1 router gemm
    // [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 7168, 256}, {10, 35, 1, 0, 0, 3, 2}},
    {{512, 7168, 256}, {48, 35, 1, 0, 0, 0, 2}},
    {{1024, 7168, 256}, {13, 35, 1, 0, 0, 1, 3}},
};

// fp8*fp8->fp32->fp16
AlgoListType fp8_algo_list = {
    // Llama-3.1-70B
    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 8192, 8192}, {393, 36, 1, 0, 0, 5, 2}},
    // [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 8192, 57344}, {10, 36, 1, 0, 0, 1, 2}},
};

void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 7> const& attr_list)
{
    auto const& [tileID, stagesID, numsK, reduction, swizzle, customOption_, cga_] = attr_list;
    uint32_t customOption = customOption_;
    uint16_t cga = cga_;
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileID, sizeof(tileID)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesID, sizeof(stagesID)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numsK, sizeof(numsK)));
    check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
        &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reduction, sizeof(reduction)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle)));
    check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
        &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, &cga, sizeof(cga)));
}

bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper> const& cublasWrapper, int32_t m,
    int32_t n, int32_t k, cublasComputeType_t compType, cudaDataType_t scaleType, cudaDataType_t aType,
    cudaDataType_t bType, cudaDataType_t outType)
{
    int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
    AlgoListType algo_list;
    if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
        && compType == CUBLAS_COMPUTE_32F)
    {
        algo_list = bf16_algo_list;
    }
    else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
    {
        algo_list = fp8_algo_list;
    }
    else
    {
        TLLM_LOG_DEBUG(
            "No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
        return false;
    }
    int const algoID = 66; // CUBLASLT_MATMUL_ALGO_NVJET
    check_cuda_error(cublasLtMatmulAlgoInit(
        cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
    if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
    {
        set_algo_attr(algo, algo_iter->second);
    }
    else
    {
        TLLM_LOG_DEBUG("No special cublasLt algo found for m=%d, k=%d, n=%d\n", m, k, n);
        return false;
    }
    TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
    return true;
}
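
// Illustrative sketch (hypothetical helper, not used by the op): how a shape resolves against the
// tuned tables above. m is rounded up to a power of two (minimum 8) before the {mp2, k, n} key is
// formed, so e.g. m = 6, k = 7168, n = 256 hits the Deepseek router entry in bf16_algo_list, while
// shapes that miss the table fall back to cuBLASLt's own heuristics.
[[maybe_unused]] bool has_tuned_algo(AlgoListType const& algo_list, int32_t m, int32_t n, int32_t k)
{
    int32_t const mp2 = std::max(nextPowerOfTwo(m), 8);
    return algo_list.find({mp2, k, n}) != algo_list.end();
}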
bool find_special_algo_deprecated(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper> const& cublasWrapper,
    int32_t m, int32_t n, int32_t k, cublasComputeType_t compType, cudaDataType_t scaleType, cudaDataType_t aType,
    cudaDataType_t bType, cudaDataType_t outType)
{
    int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
    if (aType != CUDA_R_8F_E4M3 || compType != CUBLAS_COMPUTE_32F)
    {
        return false;
    }
    int const algoID = 52;
    check_cuda_error(cublasLtMatmulAlgoInit(
        cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
    int tileID = CUBLASLT_MATMUL_TILE_256x128;
    int swizzle = 0;
    uint16_t cga = CUBLASLT_CLUSTER_SHAPE_2x1x1;
    int const stagesID = CUBLASLT_MATMUL_STAGES_128xAUTO;
    int const numsK = -1;
    int const reduction = CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
    if (mp2 <= 64)
    {
        tileID = CUBLASLT_MATMUL_TILE_64x64;
        swizzle = 1;
        if (n > k) // qkv & gate_up
            cga = CUBLASLT_CLUSTER_SHAPE_13x1x1;
        else // o & down
            cga = CUBLASLT_CLUSTER_SHAPE_10x1x1;
    }
    else if (mp2 <= 256)
    {
        if (n > k) // qkv & gate_up
            tileID = CUBLASLT_MATMUL_TILE_192x128;
        else // o & down
            tileID = CUBLASLT_MATMUL_TILE_128x128;
        swizzle = 1;
        cga = CUBLASLT_CLUSTER_SHAPE_1x2x1;
    }
    else if (mp2 <= 2048)
    {
        if (n > k) // qkv & gate_up
            tileID = CUBLASLT_MATMUL_TILE_160x128;
        else // o & down
            tileID = CUBLASLT_MATMUL_TILE_256x128;
    }
    else
    {
        return false;
    }
    set_algo_attr(algo, {tileID, stagesID, numsK, reduction, swizzle, 0, cga});
    return true;
}
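
// Illustrative sketch (hypothetical helper, not called anywhere): the tile the pre-12.8 heuristic
// above settles on for a given shape. For a small-batch decode GEMM such as m = 8, n = 57344,
// k = 8192 (Llama-3.1-70B gate_up), mp2 = 8 <= 64, so algo 52 runs with TILE_64x64, CTA swizzling
// enabled, and a 13x1x1 cluster shape because n > k.
[[maybe_unused]] cublasLtMatmulTile_t deprecated_tile_for(int32_t m, int32_t n, int32_t k)
{
    int32_t const mp2 = std::max(nextPowerOfTwo(m), 8);
    if (mp2 <= 64)
        return CUBLASLT_MATMUL_TILE_64x64;
    if (mp2 <= 256)
        return n > k ? CUBLASLT_MATMUL_TILE_192x128 : CUBLASLT_MATMUL_TILE_128x128;
    if (mp2 <= 2048)
        return n > k ? CUBLASLT_MATMUL_TILE_160x128 : CUBLASLT_MATMUL_TILE_256x128;
    return CUBLASLT_MATMUL_TILE_UNDEFINED; // no tuned choice; the caller falls back to cuBLASLt defaults
}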
void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,
    std::optional<torch::Tensor> const& scale_a, std::optional<torch::Tensor> const& scale_b, bool fast_acc = false)
{
    bool use_scale = false;
    if (scale_a.has_value() && scale_b.has_value())
    {
        use_scale = true;
    }
    int32_t m = a.sizes()[0];
    int32_t n = b.sizes()[1];
    int32_t k = a.sizes()[1];

    thread_local std::shared_ptr<CublasMMWrapper> cublasWrapper;
    if (cublasWrapper == nullptr)
    {
        auto cublasHandle = getCublasHandle();
        auto cublasLtHandle = getCublasLtHandle();
        cublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
    }

    cudaDataType_t aType = convert_torch_dtype(a.scalar_type());
    cudaDataType_t bType = convert_torch_dtype(b.scalar_type());
    cudaDataType_t outType = convert_torch_dtype(out.scalar_type());
    // hardcode compute type for FP8
    cublasComputeType_t compType = CUBLAS_COMPUTE_32F;
    cudaDataType_t scaleType = CUDA_R_32F;
    cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType);

    auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
    auto workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);

    auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
    auto* a_ptr = static_cast<void*>(a.data_ptr());
    auto* b_ptr = static_cast<void*>(b.data_ptr());
    auto* out_ptr = static_cast<void*>(out.data_ptr());
    auto* ws_ptr = static_cast<void*>(workspace.data_ptr());
    void* a_scale = nullptr;
    void* b_scale = nullptr;
    if (use_scale)
    {
        a_scale = static_cast<void*>(scale_a.value().data_ptr());
        b_scale = static_cast<void*>(scale_b.value().data_ptr());
    }

    cublasWrapper->setStream(stream);
    cublasWrapper->setWorkspace(ws_ptr);

    // set algo according to m/n/k
    cublasLtMatmulAlgo_t algo;
#if CUDART_VERSION < 12080
    // nvjet is not supported
    bool has_algo
        = find_special_algo_deprecated(algo, cublasWrapper, m, n, k, compType, scaleType, aType, bType, outType);
#else
    bool has_algo = find_special_algo(algo, cublasWrapper, m, n, k, compType, scaleType, aType, bType, outType);
#endif

    // swap A and B. A is column major, B is row major.
    cublasWrapper->createDescriptors(
        CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, /*lda=*/k, /*ldb=*/k, /*ldc=*/n, /*fastAcc=*/fast_acc);
    if (use_scale)
        cublasWrapper->setScaleDescriptors(a_scale, b_scale);
    cublasWrapper->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, /*A=*/b_ptr, /*lda=*/k, /*B=*/a_ptr, /*ldb=*/k, out_ptr,
        /*ldc=*/n, 1.0F, 0.0F, algo, has_algo, true);
    cublasWrapper->destroyDescriptors();
}

} // namespace
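
// Illustrative sketch (hypothetical test helper, not part of the op surface): why the operand swap
// in cublas_gemm_caller yields a row-major result. The column-major cuBLAS view of a row-major
// (m x k) activation is A^T, and mat_b is required to be column-major (k x n), so requesting
// C_col(n x m) = B^T * A^T with ldc = n writes exactly the row-major (m x n) product A * B.
// Tolerances below are assumptions for a low-precision GEMM, not values from the library.
[[maybe_unused]] static bool cublas_mm_matches_reference(torch::Tensor const& a_rowmajor, torch::Tensor const& b_colmajor)
{
    torch::Tensor out = torch::empty({a_rowmajor.size(0), b_colmajor.size(1)}, a_rowmajor.options());
    cublas_gemm_caller(out, a_rowmajor, b_colmajor, std::nullopt, std::nullopt, /*fast_acc=*/false);
    return torch::allclose(out.to(torch::kFloat32),
        torch::matmul(a_rowmajor.to(torch::kFloat32), b_colmajor.to(torch::kFloat32)), /*rtol=*/1e-2, /*atol=*/1e-2);
}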
Tensor& cublas_scaled_mm_out(Tensor const& mat_a, Tensor const& mat_b, Tensor const& scale_a, Tensor const& scale_b,
    std::optional<at::Tensor> const& bias, Tensor& out)
{
    // Check device
    CHECK_TH_CUDA(mat_a);
    CHECK_TH_CUDA(mat_b);
    CHECK_TH_CUDA(scale_a);
    CHECK_TH_CUDA(scale_b);
    CHECK_TH_CUDA(out);

    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && out.dim() == 2);
    TORCH_CHECK(out.sizes()[0] == mat_a.sizes()[0] && mat_a.sizes()[1] == mat_b.sizes()[0]
        && mat_b.sizes()[1] == out.sizes()[1]);
    TORCH_CHECK(scale_a.numel() == 1 || scale_a.numel() == mat_a.sizes()[0]);
    TORCH_CHECK(scale_b.numel() == 1 || scale_b.numel() == mat_b.sizes()[1]);

    // Check for strides and alignment
    TORCH_CHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1);           // Row-major
    TORCH_CHECK(mat_b.strides()[0] == 1);                                    // Column-major
    TORCH_CHECK(out.strides()[0] % 16 == 0 && mat_b.strides()[1] % 16 == 0); // 16 Byte Alignment
    TORCH_CHECK(scale_a.is_contiguous() && scale_b.is_contiguous());

    TORCH_CHECK(!bias.has_value(), "bias is not supported yet");

    TORCH_CHECK(mat_a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(mat_b.dtype() == torch::kFloat8_e4m3fn);

    cublas_gemm_caller(out, mat_a, mat_b, scale_a, scale_b, true);
    return out;
}

Tensor cublas_scaled_mm(Tensor const& mat_a, Tensor const& mat_b, Tensor const& scale_a, Tensor const& scale_b,
    std::optional<at::Tensor> const& bias, std::optional<c10::ScalarType> out_dtype, bool to_userbuffers = false)
{
    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2);
    auto const out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
    std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};

    Tensor out;
    if (to_userbuffers)
    {
        out = torch_ext::create_userbuffers_tensor(output_size, out_dtype_).first;
    }
    else
    {
        out = at::empty(output_size, mat_a.options().dtype(out_dtype_));
    }

    return cublas_scaled_mm_out(mat_a, mat_b, scale_a, scale_b, bias, out);
}

Tensor& cublas_mm_out(Tensor const& mat_a, Tensor const& mat_b, std::optional<at::Tensor> const& bias, Tensor& out)
{
    // Check device
    CHECK_TH_CUDA(mat_a);
    CHECK_TH_CUDA(mat_b);
    CHECK_TH_CUDA(out);

    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && out.dim() == 2);
    // TODO: consider removing mat_b.to() and adding explicit transa & transb flags like TRT's matmul
    TORCH_CHECK(out.sizes()[0] == mat_a.sizes()[0] && mat_a.sizes()[1] == mat_b.sizes()[0]
        && mat_b.sizes()[1] == out.sizes()[1]);

    // Check for strides and alignment
    TORCH_CHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1); // Row-major
    TORCH_CHECK(mat_b.strides()[0] == 1);                          // Column-major

    TORCH_CHECK(!bias.has_value(), "bias is not supported yet");

    cublas_gemm_caller(out, mat_a, mat_b, at::nullopt, at::nullopt, false);
    return out;
}

Tensor cublas_mm(Tensor const& mat_a, Tensor const& mat_b, std::optional<at::Tensor> const& bias,
    std::optional<c10::ScalarType> out_dtype)
{
    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2);
    auto const out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
    std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};
    Tensor out = at::empty(output_size, mat_a.options().dtype(out_dtype_));
    return cublas_mm_out(mat_a, mat_b, bias, out);
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "cublas_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
        " ScalarType? out_dtype, bool to_userbuffers=False) -> (Tensor out)");
    m.def(
        "cublas_scaled_mm_out(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
        " int userbuffers_id, Tensor! out) -> (Tensor out)");
    m.def("cublas_mm(Tensor mat_a, Tensor mat_b, Tensor? bias, ScalarType? out_dtype) -> (Tensor out)");
    m.def("cublas_mm_out(Tensor mat_a, Tensor mat_b, Tensor? bias, Tensor! out) -> (Tensor out)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("cublas_scaled_mm", &torch_ext::cublas_scaled_mm);
    m.impl("cublas_mm", &torch_ext::cublas_mm);
}
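
// Usage sketch (illustrative only; the helper name is hypothetical and nothing here is part of the
// library). It shows how the FP8 path above might be exercised directly from C++ on an FP8-capable
// GPU; from Python the same entry points are exposed as torch.ops.trtllm.cublas_mm and
// torch.ops.trtllm.cublas_scaled_mm.
[[maybe_unused]] static torch::Tensor example_fp8_scaled_mm()
{
    auto const half_opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
    // Row-major FP8 activations and a column-major FP8 weight view, matching the layout checks above.
    auto a = torch::randn({8, 8192}, half_opts).to(torch::kFloat8_e4m3fn);
    auto b = torch::randn({8192, 8192}, half_opts).to(torch::kFloat8_e4m3fn).t().contiguous().t();
    auto scale_a = torch::ones({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    auto scale_b = torch::ones({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    return torch_ext::cublas_scaled_mm(
        a, b, scale_a, scale_b, /*bias=*/std::nullopt, /*out_dtype=*/torch::kFloat16, /*to_userbuffers=*/false);
}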