* Instead of allocating UserBuffers at the beginning of the runtime, UB buffers are now managed by a global allocator. The allocator dynamically assigns a free UB buffer, or allocates a new one, for a torch tensor, which makes UserBuffers easier to use. * In the common use case, the UserBuffers are allocated during the warm-up stage, so there is no dynamic allocation during inference. * The UB fusion pattern is rewritten on top of the new UB allocator. It contains the following passes: 1. Fuse quant with allreduce, replace it with the UB implementation, and insert a copy_to_userbuffers; the normal allreduce still does not support FP8 quant, so this must be done in the UB pass. 2. Convert all supported allreduce ops to UB and insert copy_to_userbuffers. 3. Fuse the op preceding the allreduce with the copy_to_userbuffers, so that the op writes directly into the userbuffer. 4. Remove the userbuffers finalize if the output is connected to another UB allreduce. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
358 lines · 14 KiB · C++
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
|
|
#include "tensorrt_llm/plugins/common/plugin.h"
|
|
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
|
|
#include "tensorrt_llm/runtime/torchUtils.h"
|
|
#include "tensorrt_llm/thop/thUtils.h"
|
|
#include "userbuffersTensor.h"
|
|
#include <array>
|
|
#include <cublasLt.h>
|
|
#include <torch/extension.h>
|
|
#include <unordered_map>
|
|
|
|
using torch::Tensor;

namespace torch_ext
{

namespace
{

using tensorrt_llm::common::check;
using tensorrt_llm::common::CublasMMWrapper;

struct hash_tuple
{
    size_t operator()(std::tuple<int, int, int> const& x) const
    {
        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
    }
};

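// Note: XOR-ing the three tuple members is a deliberately simple hash. XOR is commutative, so permuted
// (mp2, k, n) keys land in the same bucket, but the algo tables below hold only a handful of entries, so a
// collision merely costs an extra key-equality check during lookup.
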
// Obtained from cublasTest matmultFind.
// Key {mp2, k, n} -> value {m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga};
// the algo ID itself (66, NVJET) is set separately in find_special_algo.
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 7>, hash_tuple>;

// bf16*bf16->fp32->bf16
AlgoListType bf16_algo_list = {
    // Deepseek v3/R1 fused_a
    // [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 7168, 2112}, {10, 35, 1, 0, 0, 5, 2}},
    // Deepseek v3/R1 router gemm
    // [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 7168, 256}, {10, 35, 1, 0, 0, 3, 2}},
    {{512, 7168, 256}, {48, 35, 1, 0, 0, 0, 2}},
    {{1024, 7168, 256}, {13, 35, 1, 0, 0, 1, 3}},
};

// fp8*fp8->fp32->fp16
AlgoListType fp8_algo_list = {
    // Llama-3.1-70B
    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 8192, 8192}, {393, 36, 1, 0, 0, 5, 2}},
    // [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
    {{8, 8192, 57344}, {10, 36, 1, 0, 0, 1, 2}},
};

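// To add a new tuned shape, run cublasTest matmultFind for the target (m, k, n) and copy the reported flags into
// a table entry, dropping -algo66 (the algo ID is fixed below) as well as -m_mma and -m_scheduling (not set by
// set_algo_attr). A hypothetical example, with made-up flag values for illustration only:
//   [-algo66 -m_tile20 -m_stages30 -m_numsK1 -m_reduction0 -m_swizzle1 -m_custom0 -m_mma0 -m_cga2 -m_scheduling1]
// for a shape whose m rounds up to 16 with k = 4096, n = 4096 would become:
//   {{16, 4096, 4096}, {20, 30, 1, 0, 1, 0, 2}},
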
void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 7> const& attr_list)
{
    auto const& [tileID, stagesID, numsK, reduction, swizzle, customOption_, cga_] = attr_list;
    uint32_t customOption = customOption_;
    uint16_t cga = cga_;
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileID, sizeof(tileID)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesID, sizeof(stagesID)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numsK, sizeof(numsK)));
    check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
        &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reduction, sizeof(reduction)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle)));
    check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
        &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)));
    check_cuda_error(
        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, &cga, sizeof(cga)));
}

bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper> const& cublasWrapper, int32_t m,
    int32_t n, int32_t k, cublasComputeType_t compType, cudaDataType_t scaleType, cudaDataType_t aType,
    cudaDataType_t bType, cudaDataType_t outType)
{
    int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
    AlgoListType algo_list;
    if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
        && compType == CUBLAS_COMPUTE_32F)
    {
        algo_list = bf16_algo_list;
    }
    else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
    {
        algo_list = fp8_algo_list;
    }
    else
    {
        TLLM_LOG_DEBUG(
            "No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
        return false;
    }
    int const algoID = 66; // CUBLASLT_MATMUL_ALGO_NVJET
    check_cuda_error(cublasLtMatmulAlgoInit(
        cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
    if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
    {
        set_algo_attr(algo, algo_iter->second);
    }
    else
    {
        TLLM_LOG_DEBUG("No special cublasLt algo found for m=%d, k=%d, n=%d\n", m, k, n);
        return false;
    }
    TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
    return true;
}

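// Worked example of the lookup above: a bf16 GEMM with m = 4, k = 7168, n = 2112 rounds m up to
// mp2 = max(nextPowerOfTwo(4), 8) = 8 and hits the DeepSeek fused_a entry in bf16_algo_list, so the NVJET
// algo (ID 66) is configured with that entry's tile/stage/cluster attributes. With m = 9 the key becomes
// {16, 7168, 2112}, which has no entry, so the function returns false and the caller passes has_algo=false
// to CublasMMWrapper::Gemm, leaving algorithm selection to the wrapper's default path.
//
// The variant below is the pre-CUDA 12.8 fallback: NVJET (algo 66) is unavailable there, so it configures the
// tile-based algo 52 for FP8 GEMMs from a small size-bucket heuristic instead of a measured table.
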
bool find_special_algo_deprecated(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper> const& cublasWrapper,
    int32_t m, int32_t n, int32_t k, cublasComputeType_t compType, cudaDataType_t scaleType, cudaDataType_t aType,
    cudaDataType_t bType, cudaDataType_t outType)
{
    int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
    if (aType != CUDA_R_8F_E4M3 || compType != CUBLAS_COMPUTE_32F)
    {
        return false;
    }
    int const algoID = 52;
    check_cuda_error(cublasLtMatmulAlgoInit(
        cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
    int tileID = CUBLASLT_MATMUL_TILE_256x128;
    int swizzle = 0;
    uint16_t cga = CUBLASLT_CLUSTER_SHAPE_2x1x1;
    int const stagesID = CUBLASLT_MATMUL_STAGES_128xAUTO;
    int const numsK = -1;
    int const reduction = CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
    if (mp2 <= 64)
    {
        tileID = CUBLASLT_MATMUL_TILE_64x64;
        swizzle = 1;
        if (n > k) // qkv & gate_up
            cga = CUBLASLT_CLUSTER_SHAPE_13x1x1;
        else // o & down
            cga = CUBLASLT_CLUSTER_SHAPE_10x1x1;
    }
    else if (mp2 <= 256)
    {
        if (n > k) // qkv & gate_up
            tileID = CUBLASLT_MATMUL_TILE_192x128;
        else // o & down
            tileID = CUBLASLT_MATMUL_TILE_128x128;
        swizzle = 1;
        cga = CUBLASLT_CLUSTER_SHAPE_1x2x1;
    }
    else if (mp2 <= 2048)
    {
        if (n > k) // qkv & gate_up
            tileID = CUBLASLT_MATMUL_TILE_160x128;
        else // o & down
            tileID = CUBLASLT_MATMUL_TILE_256x128;
    }
    else
    {
        return false;
    }
    set_algo_attr(algo, {tileID, stagesID, numsK, reduction, swizzle, 0, cga});
    return true;
}

void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,
    std::optional<at::Tensor> const& scale_a, std::optional<at::Tensor> const& scale_b, bool fast_acc = false)
{
    bool use_scale = false;
    if (scale_a.has_value() && scale_b.has_value())
    {
        use_scale = true;
    }

    int32_t m = a.sizes()[0];
    int32_t n = b.sizes()[1];
    int32_t k = a.sizes()[1];

    thread_local std::shared_ptr<CublasMMWrapper> cublasWrapper;
    if (cublasWrapper == nullptr)
    {
        auto cublasHandle = getCublasHandle();
        auto cublasLtHandle = getCublasLtHandle();
        cublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
    }

    cudaDataType_t aType = convert_torch_dtype(a.scalar_type());
    cudaDataType_t bType = convert_torch_dtype(b.scalar_type());
    cudaDataType_t outType = convert_torch_dtype(out.scalar_type());

    // hardcode compute type for FP8
    cublasComputeType_t compType = CUBLAS_COMPUTE_32F;
    cudaDataType_t scaleType = CUDA_R_32F;
    cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType);

    auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
    auto workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);

    auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

    auto* a_ptr = static_cast<void*>(a.data_ptr());
    auto* b_ptr = static_cast<void*>(b.data_ptr());
    auto* out_ptr = static_cast<void*>(out.data_ptr());
    auto* ws_ptr = static_cast<void*>(workspace.data_ptr());
    void* a_scale = nullptr;
    void* b_scale = nullptr;
    if (use_scale)
    {
        a_scale = static_cast<void*>(scale_a.value().data_ptr());
        b_scale = static_cast<void*>(scale_b.value().data_ptr());
    }

    cublasWrapper->setStream(stream);
    cublasWrapper->setWorkspace(ws_ptr);

    // set algo according to m/n/k
    cublasLtMatmulAlgo_t algo;
#if CUDART_VERSION < 12080
    // NVJET (algo 66) is not supported before CUDA 12.8
    bool has_algo
        = find_special_algo_deprecated(algo, cublasWrapper, m, n, k, compType, scaleType, aType, bType, outType);
#else
    bool has_algo = find_special_algo(algo, cublasWrapper, m, n, k, compType, scaleType, aType, bType, outType);
#endif

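    // Why the operands are swapped below: cuBLAS works on column-major matrices, while `a` and `out` are
    // row-major torch tensors and `b` is required to be column-major by the callers' stride checks. A row-major
    // (m x k) `a` reinterpreted as column-major is a^T (k x m), and the row-major (m x n) `out` buffer is filled
    // as out^T (n x m). The call therefore computes out^T = b^T * a^T, with cuBLAS's A = b (transposed via
    // CUBLAS_OP_T) and B = a (taken as-is), which is exactly out = a * b in row-major terms.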
    // swap A and B. A is column major, B is row major.
    cublasWrapper->createDescriptors(
        CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, /*lda=*/k, /*ldb=*/k, /*ldc=*/n, /*fastAcc=*/fast_acc);
    if (use_scale)
        cublasWrapper->setScaleDescriptors(a_scale, b_scale);
    cublasWrapper->Gemm(CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, /*A=*/b_ptr, /*lda=*/k, /*B=*/a_ptr, /*ldb=*/k, out_ptr,
        /*ldc=*/n, 1.0F, 0.0F, algo, has_algo, true);
    cublasWrapper->destroyDescriptors();
}

} // namespace

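// The FP8 entry points below take scale factors for A and B: the checks in cublas_scaled_mm_out accept either
// a single per-tensor scale (numel() == 1) or one scale per row of mat_a / per column of mat_b, and both scale
// tensors must be contiguous. The raw scale pointers are handed to cublasLt via setScaleDescriptors.
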
Tensor& cublas_scaled_mm_out(Tensor const& mat_a, Tensor const& mat_b, Tensor const& scale_a, Tensor const& scale_b,
    std::optional<at::Tensor> const& bias, Tensor& out)
{
    // Check device
    CHECK_TH_CUDA(mat_a);
    CHECK_TH_CUDA(mat_b);
    CHECK_TH_CUDA(scale_a);
    CHECK_TH_CUDA(scale_b);
    CHECK_TH_CUDA(out);

    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && out.dim() == 2);
    TORCH_CHECK(out.sizes()[0] == mat_a.sizes()[0] && mat_a.sizes()[1] == mat_b.sizes()[0]
        && mat_b.sizes()[1] == out.sizes()[1]);
    TORCH_CHECK(scale_a.numel() == 1 || scale_a.numel() == mat_a.sizes()[0]);
    TORCH_CHECK(scale_b.numel() == 1 || scale_b.numel() == mat_b.sizes()[1]);

    // Check for strides and alignment
    TORCH_CHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1); // Row-major
    TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major
    TORCH_CHECK(out.strides()[0] % 16 == 0 && mat_b.strides()[1] % 16 == 0); // 16 Byte Alignment
    TORCH_CHECK(scale_a.is_contiguous() && scale_b.is_contiguous());

    TORCH_CHECK(!bias.has_value(), "bias is not supported yet");

    TORCH_CHECK(mat_a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(mat_b.dtype() == torch::kFloat8_e4m3fn);

    cublas_gemm_caller(out, mat_a, mat_b, scale_a, scale_b, true);
    return out;
}

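// When to_userbuffers is true, the output of cublas_scaled_mm is allocated from the UserBuffers pool via
// create_userbuffers_tensor() instead of a regular CUDA allocation, so a following UB allreduce can consume
// the GEMM result in place (see the UB fusion passes described in the commit message above).
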
Tensor cublas_scaled_mm(Tensor const& mat_a, Tensor const& mat_b, Tensor const& scale_a, Tensor const& scale_b,
    std::optional<at::Tensor> const& bias, std::optional<c10::ScalarType> out_dtype, bool to_userbuffers = false)
{
    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2);
    auto const out_dtype_ = out_dtype.value_or(mat_a.scalar_type());

    std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};

    Tensor out;
    if (to_userbuffers)
    {
        out = torch_ext::create_userbuffers_tensor(output_size, out_dtype_).first;
    }
    else
    {
        out = at::empty(output_size, mat_a.options().dtype(out_dtype_));
    }

    return cublas_scaled_mm_out(mat_a, mat_b, scale_a, scale_b, bias, out);
}

Tensor& cublas_mm_out(Tensor const& mat_a, Tensor const& mat_b, std::optional<at::Tensor> const& bias, Tensor& out)
{
    // Check device
    CHECK_TH_CUDA(mat_a);
    CHECK_TH_CUDA(mat_b);
    CHECK_TH_CUDA(out);

    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && out.dim() == 2);
    // TODO: consider removing mat_b.to() and adding extra transa & transb flags like TRT's matmul
    TORCH_CHECK(out.sizes()[0] == mat_a.sizes()[0] && mat_a.sizes()[1] == mat_b.sizes()[0]
        && mat_b.sizes()[1] == out.sizes()[1]);

    // Check for strides and alignment
    TORCH_CHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1); // Row-major
    TORCH_CHECK(mat_b.strides()[0] == 1); // Column-major

    TORCH_CHECK(!bias.has_value(), "bias is not supported yet");

    cublas_gemm_caller(out, mat_a, mat_b, at::nullopt, at::nullopt, false);
    return out;
}

Tensor cublas_mm(Tensor const& mat_a, Tensor const& mat_b, std::optional<at::Tensor> const& bias,
    std::optional<c10::ScalarType> out_dtype)
{
    TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2);
    auto const out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
    std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};
    Tensor out = at::empty(output_size, mat_a.options().dtype(out_dtype_));
    return cublas_mm_out(mat_a, mat_b, bias, out);
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "cublas_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
        " ScalarType? out_dtype, bool to_userbuffers=False) -> (Tensor out)");
    m.def(
        "cublas_scaled_mm_out(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
        " int userbuffers_id, Tensor! out) -> (Tensor out)");
    m.def("cublas_mm(Tensor mat_a, Tensor mat_b, Tensor? bias, ScalarType? out_dtype) -> (Tensor out)");
    m.def("cublas_mm_out(Tensor mat_a, Tensor mat_b, Tensor? bias, Tensor! out) -> (Tensor out)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("cublas_scaled_mm", &torch_ext::cublas_scaled_mm);
    m.impl("cublas_mm", &torch_ext::cublas_mm);
}
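
// Illustrative usage sketch (not built as part of this file): calling the unscaled op through its C++ entry
// point. The shapes are assumptions chosen to hit the bf16 table above; mat_b is made column-major by
// transposing a row-major tensor so the stride checks in cublas_mm_out pass.
//
//   at::Tensor a = at::randn({8, 7168}, at::dtype(at::kBFloat16).device(at::kCUDA));
//   at::Tensor b = at::randn({2112, 7168}, at::dtype(at::kBFloat16).device(at::kCUDA)).t(); // 7168 x 2112, column-major
//   at::Tensor c = torch_ext::cublas_mm(a, b, /*bias=*/std::nullopt, /*out_dtype=*/std::nullopt); // 8 x 2112, bf16
//
// From Python, the same op should be reachable as torch.ops.trtllm.cublas_mm(a, b, None, None) once the
// extension library is loaded.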