TensorRT-LLMs/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu
Yihan Wang 9df4dad3b6
[None][fix] Introduce inline namespace to avoid symbol collision (#9541)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
2025-12-12 23:32:15 +08:00

297 lines
14 KiB
Plaintext

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "splitkGroupGemm.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h"
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
// Size in bytes of the problem-size section of the parameter workspace: one
// cutlass::gemm::GemmCoord per problem, rounded up to a multiple of 16 bytes.
inline int64_t getGemmCoordSize(int64_t problemCount)
{
    auto const rawBytes = problemCount * sizeof(cutlass::gemm::GemmCoord);
    auto const paddedBytes = tensorrt_llm::common::divUp(rawBytes, 16) * 16;
    return static_cast<int64_t>(paddedBytes);
}
// Size in bytes of one pointer-array section of the parameter workspace: one
// pointer per problem, rounded up to a multiple of 16 bytes.
inline int64_t getPtrSize(int64_t problemCount)
{
    auto const rawBytes = problemCount * sizeof(half*);
    auto const paddedBytes = tensorrt_llm::common::divUp(rawBytes, 16) * 16;
    return static_cast<int64_t>(paddedBytes);
}
// Size in bytes of one leading-dimension section (lda/ldb/ldc/ldd) of the
// parameter workspace: one int64_t per problem, rounded up to a multiple of 16 bytes.
inline int64_t getLddSize(int64_t problemCount)
{
    auto const rawBytes = problemCount * sizeof(int64_t);
    auto const paddedBytes = tensorrt_llm::common::divUp(rawBytes, 16) * 16;
    return static_cast<int64_t>(paddedBytes);
}
// Size in bytes of the per-problem offset section of the parameter workspace:
// one int64_t per problem, rounded up to a multiple of 16 bytes.
inline int64_t getOffsetSize(int64_t problemCount)
{
    auto const rawBytes = problemCount * sizeof(int64_t);
    auto const paddedBytes = tensorrt_llm::common::divUp(rawBytes, 16) * 16;
    return static_cast<int64_t>(paddedBytes);
}
int64_t getSplitkGroupedGemmParamsWorkSpaceSize(int64_t problemCount)
{
auto gemm_coord_size = getGemmCoordSize(problemCount);
auto ptr_size = 4 * getPtrSize(problemCount);
auto ldd_size = 4 * getLddSize(problemCount);
auto offset_size = getOffsetSize(problemCount);
return gemm_coord_size + ptr_size + ldd_size + offset_size;
}
// Runs one CUTLASS split-K grouped GEMM (D_i = A_i * B_i with alpha = 1, beta = 0) for every problem in
// problemSizes, using a compile-time tile configuration.
//
// Template parameters:
//   M1/N1/K1      threadblock tile shape; M2/N2/K2 warp tile shape (instruction shape is fixed to 16x8x16).
//   cutlassType   element type of A, B, C and D; accumulation is done in float.
//   kAlignmentAB  vector alignment (elements) required of the A/B leading dimensions.
//   kAlignmentC   vector alignment (elements) required of the C/D leading dimensions.
//   kStages       mainloop pipeline stage count.
// Layouts: A row-major, B column-major, C/D row-major. Targets cutlass::arch::Sm80 tensor-op kernels.
//
// The per-problem parameter arrays are staged in a host buffer with the same byte layout as
// gemmParamsWorkSpace and copied to the device with a single copy on `stream`:
//   [GemmCoord x P][A ptrs][B ptrs][C ptrs][D ptrs][lda][ldb][ldc][ldd][offsets]  (each section 16B-padded)
// gemmParamsWorkSpaceSize must cover that layout (see getSplitkGroupedGemmParamsWorkSpaceSize), and
// gemmWorkSpaceSize must cover Gemm::get_workspace_size(); otherwise a TLLM_CHECK fails.
//
// NOTE(review): problemSizes is taken by value and problemSizes.data() is handed to the CUTLASS Arguments as the
// host problem-size pointer; presumably a mutable copy is what that constructor needs — confirm before changing.
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType, int kAlignmentAB, int kAlignmentC,
    int kStages>
void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problemSizes, std::vector<void*> const& ptrA,
    std::vector<void*> const& ptrB, std::vector<void*> const& ptrC, std::vector<void*> const& ptrD,
    void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize,
    int splitKSlices, cudaStream_t stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    using ElementA = cutlassType;
    using ElementB = cutlassType;
    using ElementOutput = float;
    using ElementAccumulator = float;
    using ElementFinalOutput = cutlassType;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using LayoutC = cutlass::layout::RowMajor;
    using GemmKernel = typename cutlass::gemm::kernel::DefaultSplitkGemmGrouped<ElementA, LayoutA,
        cutlass::ComplexTransform::kNone, kAlignmentAB, ElementB, LayoutB, cutlass::ComplexTransform::kNone,
        kAlignmentAB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<ElementOutput, kAlignmentC, ElementAccumulator,
            ElementAccumulator>,
        // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
        // This parameter is passed in at present to match the APIs of other kernels. The parameter
        // is unused within the kernel.
        cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle, kStages,
        cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
    using Gemm = cutlass::gemm::device::SplitkGemmGrouped<GemmKernel>;

    float alpha = 1.0f;
    float beta = 0.0f;
    typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha, beta);

    int const problemCount = static_cast<int>(problemSizes.size());

    // Section sizes (bytes, 16B-padded) and byte offsets inside the parameter workspace. The host staging
    // buffer and the device buffer share this layout, so pointers into both derive from the same offsets.
    int64_t const gemm_coord_size = getGemmCoordSize(problemCount);
    int64_t const ptr_size = getPtrSize(problemCount);
    int64_t const ldd_size = getLddSize(problemCount);
    int64_t const offset_size = getOffsetSize(problemCount);

    int64_t const off_ptr_A = gemm_coord_size;
    int64_t const off_ptr_B = off_ptr_A + ptr_size;
    int64_t const off_ptr_C = off_ptr_B + ptr_size;
    int64_t const off_ptr_D = off_ptr_C + ptr_size;
    int64_t const off_lda = off_ptr_D + ptr_size;
    int64_t const off_ldb = off_lda + ldd_size;
    int64_t const off_ldc = off_ldb + ldd_size;
    int64_t const off_ldd = off_ldc + ldd_size;
    int64_t const off_offset = off_ldd + ldd_size;
    int64_t const required_size = off_offset + offset_size;

    // Refuse to write past the end of the caller-provided device buffer.
    TLLM_CHECK_WITH_INFO(required_size <= gemmParamsWorkSpaceSize,
        "gemmParamsWorkSpaceSize (%zu) is smaller than the required size (%zu).",
        static_cast<size_t>(gemmParamsWorkSpaceSize), static_cast<size_t>(required_size));

    // Owned, zero-initialized host staging buffer: released automatically even if a TLLM_CHECK below throws.
    std::vector<char> host_workspace(static_cast<size_t>(required_size));
    char* const host_base = host_workspace.data();
    char* const device_base = static_cast<char*>(gemmParamsWorkSpace);

    cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_base);
    ElementA** ptr_A_host = reinterpret_cast<ElementA**>(host_base + off_ptr_A);
    ElementB** ptr_B_host = reinterpret_cast<ElementB**>(host_base + off_ptr_B);
    ElementFinalOutput** ptr_C_host = reinterpret_cast<ElementFinalOutput**>(host_base + off_ptr_C);
    ElementFinalOutput** ptr_D_host = reinterpret_cast<ElementFinalOutput**>(host_base + off_ptr_D);
    int64_t* lda_host = reinterpret_cast<int64_t*>(host_base + off_lda);
    int64_t* ldb_host = reinterpret_cast<int64_t*>(host_base + off_ldb);
    int64_t* ldc_host = reinterpret_cast<int64_t*>(host_base + off_ldc);
    int64_t* ldd_host = reinterpret_cast<int64_t*>(host_base + off_ldd);
    int64_t* offset_host = reinterpret_cast<int64_t*>(host_base + off_offset);

    int64_t cumulative_offsets = 0;
    for (int32_t i = 0; i < problemCount; ++i)
    {
        auto const& problem = problemSizes.at(i);
        problem_sizes_host[i] = problem;
        ptr_A_host[i] = static_cast<ElementA*>(ptrA.at(i));
        ptr_B_host[i] = static_cast<ElementB*>(ptrB.at(i));
        ptr_C_host[i] = static_cast<ElementFinalOutput*>(ptrC.at(i));
        ptr_D_host[i] = static_cast<ElementFinalOutput*>(ptrD.at(i));

        // Packed leading dimensions; each must be a multiple of the vector alignment the kernel was built for.
        lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
        TLLM_CHECK(lda_host[i] % kAlignmentAB == 0);
        ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
        TLLM_CHECK(ldb_host[i] % kAlignmentAB == 0);
        ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
        TLLM_CHECK(ldc_host[i] % kAlignmentC == 0);
        ldd_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
        TLLM_CHECK(ldd_host[i] % kAlignmentC == 0);

        // Running sum of m * n over the preceding problems. Widen before multiplying so large
        // problems cannot overflow int.
        offset_host[i] = cumulative_offsets;
        cumulative_offsets += static_cast<int64_t>(problem.m()) * problem.n();
    }

    cutlass::gemm::GemmCoord* problem_sizes_device = reinterpret_cast<cutlass::gemm::GemmCoord*>(device_base);
    ElementA** ptr_A = reinterpret_cast<ElementA**>(device_base + off_ptr_A);
    ElementB** ptr_B = reinterpret_cast<ElementB**>(device_base + off_ptr_B);
    ElementFinalOutput** ptr_C = reinterpret_cast<ElementFinalOutput**>(device_base + off_ptr_C);
    ElementFinalOutput** ptr_D = reinterpret_cast<ElementFinalOutput**>(device_base + off_ptr_D);
    int64_t* lda = reinterpret_cast<int64_t*>(device_base + off_lda);
    int64_t* ldb = reinterpret_cast<int64_t*>(device_base + off_ldb);
    int64_t* ldc = reinterpret_cast<int64_t*>(device_base + off_ldc);
    int64_t* ldd = reinterpret_cast<int64_t*>(device_base + off_ldd);
    int64_t* offset = reinterpret_cast<int64_t*>(device_base + off_offset);

    // Single host -> device copy of the whole staged layout.
    if (required_size > 0)
    {
        tensorrt_llm::common::cudaAutoCpy(reinterpret_cast<int8_t*>(gemmParamsWorkSpace),
            reinterpret_cast<int8_t*>(host_base), static_cast<size_t>(required_size), stream);
    }

    int threadblock_count = Gemm::sufficient(problemSizes.data(), problemCount);
    typename Gemm::Arguments args(problem_sizes_device, problemCount, threadblock_count, epilogue_op, ptr_A, ptr_B,
        ptr_C, ptr_D, lda, ldb, ldc, ldd, problemSizes.data(), splitKSlices, offset);

    // Initialize the GEMM object
    Gemm gemm;
    size_t const workSpaceSize = gemm.get_workspace_size(args);
    TLLM_CHECK_WITH_INFO(workSpaceSize <= static_cast<size_t>(gemmWorkSpaceSize),
        "gemmWorkSpaceSize (%zu) is smaller than the required workSpaceSize (%zu).",
        static_cast<size_t>(gemmWorkSpaceSize), workSpaceSize);
    cutlass::Status status = gemm.initialize(args, gemmWorkSpace);
    TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to initialize CUTLASS Grouped GEMM kernel.");

    // Run the grouped GEMM object
    status = gemm.run(stream);
    TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
    sync_check_cuda_error(stream);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// Dispatches splitkGroupedGemm_ on the runtime element type, keeping the tile configuration fixed.
// Supported: kHALF (cutlass::half_t) and, when built with ENABLE_BF16, kBF16 (cutlass::bfloat16_t).
// kFLOAT and any other data type fail a TLLM_CHECK instead of returning with D unwritten.
template <int M1, int N1, int K1, int M2, int N2, int K2, int kAlignmentAB, int kAlignmentC, int kStages>
void splitkGroupedGemmType_(std::vector<cutlass::gemm::GemmCoord> const& problemSizes, std::vector<void*> const& ptrA,
    std::vector<void*> const& ptrB, std::vector<void*> const& ptrC, std::vector<void*> const& ptrD,
    void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize,
    nvinfer1::DataType dataType, int splitKSlices, cudaStream_t stream)
{
    if (dataType == nvinfer1::DataType::kHALF)
    {
        splitkGroupedGemm_<M1, N1, K1, M2, N2, K2, cutlass::half_t, kAlignmentAB, kAlignmentC, kStages>(problemSizes,
            ptrA, ptrB, ptrC, ptrD, gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize,
            splitKSlices, stream);
    }
    else if (dataType == nvinfer1::DataType::kFLOAT)
    {
        TLLM_CHECK_WITH_INFO(false, "not support float input/output");
    }
#ifdef ENABLE_BF16
    else if (dataType == nvinfer1::DataType::kBF16)
    {
        splitkGroupedGemm_<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t, kAlignmentAB, kAlignmentC, kStages>(
            problemSizes, ptrA, ptrB, ptrC, ptrD, gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace,
            gemmWorkSpaceSize, splitKSlices, stream);
    }
#endif
    else
    {
        // Fail loudly: silently returning here would leave the outputs unwritten with no diagnostic.
        TLLM_CHECK_WITH_INFO(
            false, "splitkGroupedGemm: unsupported data type (%d).", static_cast<int>(dataType));
    }
}
// Runs a split-K grouped GEMM over every problem in problemSizes. The kernel configuration is chosen at
// compile time from two runtime hints:
//   isLoraIn  true  -> K >> N case (e.g. K = 1024, N = 8): 16x32x64 threadblock / warp tile, A/B alignment 8,
//                      C/D alignment picked from minKN.
//             false -> N >> K case (e.g. K = 8, N = 1024): 32x128x32 threadblock, 32x32x32 warp tile,
//                      C/D alignment 8, A/B alignment picked from minKN.
//   minKN     selects the largest alignment in {8, 4, 2, 1} that does not exceed it (and 4 or 2 pipeline
//             stages). Presumably the smallest relevant K / N across the problems — confirm with the caller.
//             Must be >= 1 when there is at least one problem.
// dataType must be kHALF, or kBF16 when built with ENABLE_BF16. splitKSlices is forwarded to the kernel.
void splitkGroupedGemm(std::vector<cutlass::gemm::GemmCoord> const& problemSizes, std::vector<void*> const& ptrA,
    std::vector<void*> const& ptrB, std::vector<void*> const& ptrC, std::vector<void*> const& ptrD,
    void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize,
    bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN, cudaStream_t stream)
{
    TLLM_LOG_TRACE("%s start, isLoraIn: %d, minKN = %d", __PRETTY_FUNCTION__, static_cast<int>(isLoraIn), minKN);
    if (problemSizes.empty())
    {
        // No problems: nothing to compute, so skip staging parameters and launching a kernel.
        return;
    }
    // Every configuration below needs minKN >= 1. Without this check an invalid minKN would fall through all
    // branches and return without computing D.
    TLLM_CHECK_WITH_INFO(minKN >= 1, "splitkGroupedGemm requires minKN >= 1, got %d.", minKN);

    if (isLoraIn)
    {
        // K >> N, like K = 1024, N = 8
        // Use larger K tile and smaller N tile
        if (minKN >= 8)
        {
            splitkGroupedGemmType_<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
        else if (minKN >= 4)
        {
            splitkGroupedGemmType_<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
        else if (minKN >= 2)
        {
            splitkGroupedGemmType_<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
        else // minKN == 1 (guaranteed >= 1 by the check above)
        {
            splitkGroupedGemmType_<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
    }
    else
    {
        // N >> K, like K = 8, N = 1024
        // Use larger N tile and smaller K tile
        if (minKN >= 8)
        {
            splitkGroupedGemmType_<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
        else if (minKN >= 4)
        {
            splitkGroupedGemmType_<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
        else if (minKN >= 2)
        {
            splitkGroupedGemmType_<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
        else // minKN == 1 (guaranteed >= 1 by the check above)
        {
            splitkGroupedGemmType_<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizes, ptrA, ptrB, ptrC, ptrD,
                gemmParamsWorkSpace, gemmParamsWorkSpaceSize, gemmWorkSpace, gemmWorkSpaceSize, dataType, splitKSlices,
                stream);
        }
    }
}
} // namespace kernels
TRTLLM_NAMESPACE_END