Feat: add deep_gemm swapAB kernel (#4430)

* feat: add deepgemm_swapAB

feat: add fp8_gemm_kernel_swapAB

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

feat: set threshold for deepgemm and deepgemm_swapAB

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

* docs: update README.md

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

* fix: std::runtime_error needs #include <stdexcept>

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

* chores: remove the redundant code

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

* feat: support for dense deep_gemm swapab

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

* chores: remove redundant code

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

---------

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
Co-authored-by: Tao Li @ NVIDIA <tali@nvidia.com>
Ruoqian Guo 2025-05-21 10:48:43 +08:00 committed by GitHub
parent 2372589689
commit db7446fda7
10 changed files with 1040 additions and 203 deletions

View File

@ -225,21 +225,37 @@ std::vector<std::filesystem::path> getJitIncludeDirs()
std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
uint32_t const block_n, uint32_t const block_k, uint32_t const num_groups, uint32_t const num_stages,
uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type)
uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type, bool swapAB = false)
{
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
std::string input_type;
switch (gemm_type)
if (!swapAB)
{
case deep_gemm::GemmType::Normal: input_type = "NormalSchedulerInput"; break;
case deep_gemm::GemmType::GroupedContiguous: input_type = "GroupedContiguousSchedulerInput"; break;
case deep_gemm::GemmType::GroupedMasked: input_type = "GroupedMaskedSchedulerInput"; break;
case deep_gemm::GemmType::GroupedWithOffset: input_type = "GroupedWithOffsetSchedulerInput"; break;
case deep_gemm::GemmType::StridedBatched: input_type = "StridedBatchedSchedulerInput"; break;
default: throw std::runtime_error("Unsupported gemm type");
switch (gemm_type)
{
case deep_gemm::GemmType::Normal: input_type = "NormalSchedulerInput"; break;
case deep_gemm::GemmType::GroupedContiguous: input_type = "GroupedContiguousSchedulerInput"; break;
case deep_gemm::GemmType::GroupedMasked: input_type = "GroupedMaskedSchedulerInput"; break;
case deep_gemm::GemmType::GroupedWithOffset: input_type = "GroupedWithOffsetSchedulerInput"; break;
case deep_gemm::GemmType::StridedBatched: input_type = "StridedBatchedSchedulerInput"; break;
default: throw std::runtime_error("Unsupported gemm type");
}
}
else
{
switch (gemm_type)
{
case deep_gemm::GemmType::Normal: input_type = "NormalSchedulerInputSwapAB"; break;
case deep_gemm::GemmType::GroupedWithOffset: input_type = "GroupedWithOffsetSchedulerInputSwapAB"; break;
default: throw std::runtime_error("Unsupported gemm type");
}
}
// Modify kernel name based on swapAB to determine which kernel function to use
std::string kernel_name = swapAB ? "fp8_gemm_kernel_swapAB" : "fp8_gemm_kernel";
std::string scheduler_name = swapAB ? "SchedulerSelectorSwapAB" : "SchedulerSelector";
// Create the kernel source code using raw string literal
std::string code = R"(
@ -265,18 +281,19 @@ std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint3
using namespace deep_gemm;
using SchedulerType =
typename SchedulerSelector<GemmType::)"
+ gemm_type_to_string(gemm_type) + R"(, )" + std::to_string(shape_n) + R"(, )" + std::to_string(shape_k)
+ R"(, )" + std::to_string(block_m) + R"(, )" + std::to_string(block_n) + R"(, )" + std::to_string(block_k)
+ R"(, )" + std::to_string(num_groups) + R"(, )" + std::to_string(num_tma_multicast) + R"(>::type;
typename )"
+ scheduler_name + R"(<GemmType::)" + gemm_type_to_string(gemm_type) + R"(, )" + std::to_string(shape_n)
+ R"(, )" + std::to_string(shape_k) + R"(, )" + std::to_string(block_m) + R"(, )" + std::to_string(block_n)
+ R"(, )" + std::to_string(block_k) + R"(, )" + std::to_string(num_groups) + R"(, )"
+ std::to_string(num_tma_multicast) + R"(>::type;
__global__ void dummy_kernel() {
void *ptr = (void *)&fp8_gemm_kernel<)"
+ std::to_string(shape_n) + R"(, )" + std::to_string(shape_k) + R"(, )" + std::to_string(block_m) + R"(, )"
+ std::to_string(block_n) + R"(, )" + std::to_string(block_k) + R"(, )" + std::to_string(num_groups) + R"(, )"
+ std::to_string(num_stages) + R"(, )" + std::to_string(kNumTMAThreads) + R"(, )"
+ std::to_string(kNumMathThreadsPerGroup) + R"(, )" + std::to_string(num_tma_multicast) + R"(, SchedulerType, )"
+ input_type + R"(>;
void *ptr = (void *)&)"
+ kernel_name + R"(<)" + std::to_string(shape_n) + R"(, )" + std::to_string(shape_k) + R"(, )"
+ std::to_string(block_m) + R"(, )" + std::to_string(block_n) + R"(, )" + std::to_string(block_k) + R"(, )"
+ std::to_string(num_groups) + R"(, )" + std::to_string(num_stages) + R"(, )" + std::to_string(kNumTMAThreads)
+ R"(, )" + std::to_string(kNumMathThreadsPerGroup) + R"(, )" + std::to_string(num_tma_multicast)
+ R"(, SchedulerType, )" + input_type + R"(>;
}
)";
@ -305,7 +322,7 @@ public:
// Build function
Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, uint32_t const block_n,
uint32_t const block_k, uint32_t const num_groups, uint32_t const num_stages, uint32_t const num_tma_multicast,
deep_gemm::GemmType const gemm_type)
deep_gemm::GemmType const gemm_type, bool swapAB = false)
{
int sm_version = tensorrt_llm::common::getSMVersion();
if (sm_version != 90)
@ -317,8 +334,9 @@ public:
}
// Build signature - simplified, no MD5 calculation
std::string name = "gemm_" + std::to_string(shape_n) + "_" + std::to_string(shape_k) + "_"
+ std::to_string(block_m) + "_" + std::to_string(block_n) + "_" + std::to_string(block_k) + "_"
std::string name = std::string(swapAB ? "gemm_swapAB_" : "gemm_") + std::to_string(shape_n) + "_"
+ std::to_string(shape_k) + "_" + std::to_string(block_m) + "_" + std::to_string(block_n) + "_"
+ std::to_string(block_k) + "_" + std::to_string(num_groups) + "_" + std::to_string(num_stages)
+ std::to_string(num_groups) + "_" + std::to_string(num_stages) + "_" + std::to_string(num_tma_multicast)
+ "_" + gemm_type_to_string(gemm_type);
std::filesystem::path path = getCacheDir() / name;
@ -393,7 +411,7 @@ public:
}
std::string code = generateKernel(
shape_n, shape_k, block_m, block_n, block_k, num_groups, num_stages, num_tma_multicast, gemm_type);
shape_n, shape_k, block_m, block_n, block_k, num_groups, num_stages, num_tma_multicast, gemm_type, swapAB);
if (kJitDebugging)
{

View File

@ -44,143 +44,6 @@
namespace deep_gemm
{
template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
uint32_t kNumStages, uint32_t kNumTMAMulticast, GemmType kGemmType>
class Gemm
{
private:
using Barrier = cuda::barrier<cuda::thread_scope_block>;
public:
Gemm() = default;
// DeepGEMM
template <typename LayoutIndexType>
static void run(__nv_bfloat16* gmem_d, float* scales_b, LayoutIndexType* grouped_layout, uint32_t shape_m,
CUtensorMap const& tma_a_desc, CUtensorMap const& tma_b_desc, CUtensorMap const& tma_scales_a_desc,
CUtensorMap const& tma_d_desc, cudaStream_t stream, int num_sms, uint32_t smem_size)
{
using SchedulerType = typename SchedulerSelector<kGemmType, SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumTMAMulticast>::type;
using InputType = typename SchedulerType::Input;
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup, kNumTMAMulticast, SchedulerType>;
DG_HOST_ASSERT(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
InputType input;
input.shape_m = shape_m;
input.grouped_layout = grouped_layout;
// Launch
auto status = cudaLaunchKernelEx(
&config, kernel, gmem_d, scales_b, input, tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
// Grouped GEMM with Offset
// `problem_m_padded_offsets` is used for reading scales, to satisfy the alignment requirements of TMA,
// each problem offset in problem_m_padded_offsets must be padded to a multiple of 4.
template <typename LayoutIndexType>
static void run(__nv_bfloat16* gmem_d, float* scales_b, LayoutIndexType* problem_m_offsets,
LayoutIndexType* problem_m_padded_offsets, CUtensorMap const& tma_a_desc, CUtensorMap const& tma_b_desc,
CUtensorMap const& tma_scales_a_desc, CUtensorMap const& tma_d_desc, cudaStream_t stream, int num_sms,
uint32_t smem_size)
{
using SchedulerType = typename SchedulerSelector<kGemmType, SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumTMAMulticast>::type;
using InputType = typename SchedulerType::Input;
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup, kNumTMAMulticast, SchedulerType>;
DG_HOST_ASSERT(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
InputType input;
input.problem_m_offsets = problem_m_offsets;
input.problem_m_padded_offsets = problem_m_padded_offsets;
// Launch
auto status = cudaLaunchKernelEx(
&config, kernel, gmem_d, scales_b, input, tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
// Batched Strided GEMM
static void run(__nv_bfloat16* gmem_d, float* scales_b, uint32_t shape_m, CUtensorMap const& tma_a_desc,
CUtensorMap const& tma_b_desc, CUtensorMap const& tma_scales_a_desc, CUtensorMap const& tma_d_desc,
uint64_t ld_a, uint64_t stride_a, uint64_t ld_b, uint64_t stride_b, uint64_t ld_d, uint64_t stride_d,
cudaStream_t stream, int num_sms, uint32_t smem_size)
{
using SchedulerType = typename SchedulerSelector<kGemmType, SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumTMAMulticast>::type;
using InputType = typename SchedulerType::Input;
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup, kNumTMAMulticast, SchedulerType>;
DG_HOST_ASSERT(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
InputType input{shape_m, ld_a, stride_a, ld_b, stride_b, ld_d, stride_d};
// Launch
auto status = cudaLaunchKernelEx(
&config, kernel, gmem_d, scales_b, input, tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
};
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m,
uint32_t block_k, uint32_t num_groups, GemmType gemm_type, uint64_t global_stride_in_bytes = 0)
@ -229,6 +92,53 @@ CUtensorMap make_tma_scales_a_offset_desc(T* global_address, int64_t max_m_padde
1, global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
CUtensorMap make_2d_tma_a_desc_swapAB(T* global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m,
uint32_t block_k, uint32_t num_groups, GemmType gemm_type, uint64_t global_stride_in_bytes = 0)
{
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (gemm_type != GemmType::Normal ? num_groups : 1), shape_k, block_m, block_k, global_stride_in_bytes);
}
template <typename T>
CUtensorMap make_2d_tma_b_desc_swapAB(T* global_address, uint32_t shape_n, uint32_t shape_k, uint32_t block_n,
uint32_t block_k, uint32_t num_groups, GemmType gemm_type, uint64_t global_stride_in_bytes = 0)
{
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
shape_n * (gemm_type == GemmType::GroupedMasked ? num_groups : 1), block_k, block_n, global_stride_in_bytes);
}
template <typename T>
CUtensorMap make_2d_tma_d_desc_swapAB(T* global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m,
uint32_t block_n, uint32_t num_groups, GemmType gemm_type, uint64_t global_stride_in_bytes = 0)
{
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_n * (gemm_type == GemmType::GroupedMasked ? num_groups : 1), shape_m, min(block_n, shape_n),
min(block_m, shape_m), global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
CUtensorMap make_2d_tma_scales_b_desc_swapAB(T* global_address, uint32_t shape_n, uint32_t shape_k, uint32_t block_n,
uint32_t block_k, uint32_t num_groups, GemmType gemm_type, uint64_t global_stride_in_bytes = 0)
{
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_n = ceil_div(shape_n, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::RowMajor,
ceil_div(shape_k, block_k)
* ((gemm_type == GemmType::GroupedMasked || gemm_type == GemmType::StridedBatched) ? num_groups : 1),
shape_n, 1, block_n, global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
CUtensorMap make_tma_scales_b_offset_desc_swapAB(T* global_address, int64_t max_n_padded_total, uint32_t shape_k,
uint32_t block_n, uint32_t block_k, uint64_t global_stride_in_bytes = 0)
{
return make_2d_tma_desc(global_address, Layout::RowMajor, ceil_div(shape_k, block_k), max_n_padded_total, 1,
block_n, global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
CUtensorMap make_2d_tma_desc(T* global_address, Layout layout, uint32_t gmem_rows, uint32_t gmem_cols,
uint32_t smem_rows, uint32_t smem_cols, uint64_t global_stride_in_bytes,
@ -294,6 +204,48 @@ void runGemm(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b,
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename LayoutIndexType>
void runGemmSwapAB(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, int ld_d,
float* scales_a, float* scales_b, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, uint32_t block_m,
uint32_t block_n, uint32_t block_k, uint32_t num_groups, uint32_t num_tma_multicast, GemmType gemm_type,
LayoutIndexType* grouped_layout, cudaStream_t stream, int num_sms, uint32_t smem_size)
{
auto tma_a_desc = make_2d_tma_a_desc_swapAB(
reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, shape_k, block_m, block_k, num_groups, gemm_type, ld_a);
auto tma_b_desc = make_2d_tma_b_desc_swapAB(
reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, block_n, block_k, num_groups, gemm_type, ld_b);
auto tma_scales_b_desc
= make_2d_tma_scales_b_desc_swapAB(scales_b, shape_n, shape_k, block_n, block_k, num_groups, gemm_type);
auto tma_d_desc = make_2d_tma_d_desc_swapAB(
reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, block_m, block_n, num_groups, gemm_type, ld_d * 2);
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(static_cast<int32_t>(block_m));
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {num_tma_multicast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
NormalSchedulerInputSwapAB input;
input.shape_n = shape_n;
input.grouped_layout = grouped_layout;
auto status = cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_a, input,
tma_a_desc, tma_b_desc, tma_scales_b_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename LayoutIndexType>
void runGemm(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, int ld_d, float* scales_a,
float* scales_b, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, uint32_t block_m, uint32_t block_n,
@ -302,12 +254,12 @@ void runGemm(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b,
uint32_t smem_size, uint32_t max_shape_m_padded)
{
auto tma_a_desc = make_2d_tma_a_desc(
reinterpret_cast<__nv_fp8_e4m3*>(mat_a), max_shape_m_padded, shape_k, block_m, block_k, num_groups, gemm_type);
reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, shape_k, block_m, block_k, num_groups, gemm_type);
auto tma_b_desc = make_2d_tma_b_desc(
reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, block_n, block_k, num_groups, gemm_type);
auto tma_scales_a_desc = make_tma_scales_a_offset_desc(scales_a, max_shape_m_padded, shape_k, block_m, block_k);
auto tma_d_desc = make_2d_tma_d_desc(
reinterpret_cast<__nv_bfloat16*>(mat_d), max_shape_m_padded, shape_n, block_m, block_n, num_groups, gemm_type);
reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, block_m, block_n, num_groups, gemm_type);
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
@ -337,6 +289,51 @@ void runGemm(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b,
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename LayoutIndexType>
void runGemmSwapAB(cudaKernel_t kernel, void* mat_a /* weight*/, int ld_a, void* mat_b /* act*/, int ld_b, void* mat_d,
int ld_d, float* scales_a /* weight scales*/, float* scales_b /* act scales*/, uint32_t shape_m, uint32_t shape_n,
uint32_t shape_k, uint32_t block_m, uint32_t block_n, uint32_t block_k, uint32_t num_groups,
uint32_t num_tma_multicast, GemmType gemm_type, LayoutIndexType* problem_n_offsets,
LayoutIndexType* problem_n_padded_offsets, cudaStream_t stream, int num_sms, uint32_t smem_size,
uint32_t max_shape_n_padded)
{
// Create tensor mappings using swapAB version functions, note the parameter order
auto tma_a_desc = make_2d_tma_a_desc_swapAB(
reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, shape_k, block_m, block_k, num_groups, gemm_type);
auto tma_b_desc = make_2d_tma_b_desc_swapAB(
reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, block_n, block_k, num_groups, gemm_type);
auto tma_scales_b_desc
= make_tma_scales_b_offset_desc_swapAB(scales_b, max_shape_n_padded, shape_k, block_n, block_k);
auto tma_d_desc = make_2d_tma_d_desc_swapAB(
reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, block_m, block_n, num_groups, gemm_type);
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(static_cast<int32_t>(block_m));
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {num_tma_multicast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
// Update input structure to use N dimension offsets
GroupedWithOffsetSchedulerInputSwapAB input;
input.problem_n_offsets = problem_n_offsets; // Now offsets are for N dimension
input.problem_n_padded_4_offsets = problem_n_padded_offsets;
auto status = cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_a, input,
tma_a_desc, tma_b_desc, tma_scales_b_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
void runGemm(cudaKernel_t kernel, void* mat_a, uint64_t ld_a, uint64_t stride_a, void* mat_b, uint64_t ld_b,
uint64_t stride_b, void* mat_d, uint64_t ld_d, uint64_t stride_d, float* scales_a, float* scales_b,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, uint32_t block_m, uint32_t block_n, uint32_t block_k,

View File

@ -481,4 +481,433 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
template <uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
uint32_t kNumStages, uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup, uint32_t kNumTMAMulticast,
typename SchedulerType, typename InputType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_kernel_swapAB(__nv_bfloat16* gmem_d, float* scales_a, InputType problem_input,
const __grid_constant__ CUtensorMap tensor_map_a, // weight (previously act)
const __grid_constant__ CUtensorMap tensor_map_b, // act (previously weight)
const __grid_constant__ CUtensorMap tensor_map_scales_b, // act scales (previously tensor_map_scales_a)
const __grid_constant__ CUtensorMap tensor_map_d)
{
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(ceil_div(BLOCK_M, BLOCK_K) == 1, "Too much A scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Shared memory
DG_STATIC_ASSERT(BLOCK_K % BLOCK_M == 0, "BLOCK_M should be 64 or 128 and BLOCK_K should be 128");
static constexpr uint32_t SMEM_D_SIZE = BLOCK_N * BLOCK_M * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); // B matrix (act) scales
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE_PADDED
= ceil_div<uint32_t>(BLOCK_N * sizeof(float), 128) * 128; // B matrix (act) scales, 128B aligned
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
static constexpr uint32_t SMEM_SCALES_A_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * sizeof(float), sizeof(Barrier))
* sizeof(Barrier); // renamed to A (weight)
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at very beginning
if (threadIdx.x == kNumMathThreads)
{
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages]; // weight
__nv_fp8_e4m3* smem_b[kNumStages]; // act
float* smem_scales_b[kNumStages]; // act scales
float* smem_scales_a; // weight scales
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
// Fill shared memory pointers
#pragma unroll
for (int i = 0; i < kNumStages; ++i)
{
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(
smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE
+ kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_B_SIZE_PER_STAGE_PADDED);
}
smem_scales_a = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE
+ kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE_PADDED));
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_a) + SMEM_SCALES_A_SIZE);
#pragma unroll
for (int i = 0; i < kNumStages; ++i)
{
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
}
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (threadIdx.x == kNumMathThreads)
{
#pragma unroll
for (int i = 0; i < kNumStages; ++i)
{
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK
{
};
struct NotDivisibleK
{
};
auto launch_k_iterations = [](auto const& func)
{
if constexpr (SHAPE_K % kFullKOfAllStages == 0)
{
for (int k_iter = 0; k_iter < kNumIterations; ++k_iter)
func(k_iter, DivisibleK{});
}
else
{
for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter)
func(k_iter, DivisibleK{});
func(kNumIterations - 1, NotDivisibleK{});
}
};
// Register reconfigurations
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = SchedulerType(problem_input);
if (threadIdx.x >= kNumMathThreads)
{
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads)
{
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx))
{
launch_k_iterations(
[&](int k_iter, auto type)
{
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages
= kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++s)
{
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Issue TMA A (weight) now without broadcasting
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier), smem_a[s], k_idx,
scheduler.get_global_m_idx(SHAPE_M, BLOCK_M, m_block_idx, n_block_idx));
// Issue TMA B (act) with broadcasting
tma_copy<kNumTMAMulticast>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx));
// Issue TMA scales_b (act scales) for B matrix
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
{
tma_copy<kNumTMAMulticast>(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K);
}
else
{
tma_copy<kNumTMAMulticast>(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N,
scheduler.get_global_scales_b_idx(k_idx / BLOCK_K));
}
full_barrier.arrive_and_expect_tx(
SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++s)
{
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1)
{
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
}
else
{
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
// Each thread loads consecutive 2 scales
const uint32_t scale_offset = (lane_idx % 4) * 2;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx))
{
// Load weight scales (scales_a) - these are associated with tensor_map_a (weight)
// Decide the number of scales A to load
DG_STATIC_ASSERT(SHAPE_M % 8 == 0, "Invalid shape M");
uint32_t num_scales_a = SHAPE_K_SCALES;
// Load A scales with math warp-groups (weight scales)
if (threadIdx.x >= 32)
{
auto num_previous_lines
= scheduler.get_global_scales_a_idx(ceil_div(SHAPE_M, BLOCK_K), 0, 0, n_block_idx);
auto local_scales_a
= scales_a + (num_previous_lines + ((m_block_idx * BLOCK_M) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_a; i += kNumMathThreads - 32)
st_shared(smem_scales_a + i, __ldg(local_scales_a + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s)
{
if constexpr (kNumTMAMulticast == 1)
{
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
}
else
{
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// Launch MMAs
launch_k_iterations(
[&](int k_iter, auto type)
{
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages
= kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++s)
{
// Read weight scales (A scales)
float scale_a_0 = ld_shared(smem_scales_a + k_iter * kNumStages + s);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled
// block polluting the results
// Each thread reads two consecutive B scales; in total each thread needs to read WGMMA::N / 4 * 2 B
// scales
float scale_0_0[WGMMA::kNumAccum / 4], scale_0_1[WGMMA::kNumAccum / 4];
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
{
float2 scale_b
= ld_shared(reinterpret_cast<const float2*>(smem_scales_b[s] + i * 8 + scale_offset));
scale_0_0[i] = scale_a_0 * scale_b.x;
scale_0_1[i] = scale_a_0 * scale_b.y;
}
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++k)
{
auto desc_a
= make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
{
final_accum[i * 4 + 0] += scale_0_0[i] * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1[i] * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_0_0[i] * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_0_1[i] * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++s)
{
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
int tid = 0;
if (lane_idx < 8)
{
tid = lane_idx * BLOCK_M;
}
else if (lane_idx < 16)
{
tid = (lane_idx - 8) * BLOCK_M + 8;
}
else if (lane_idx < 24)
{
tid = (lane_idx - 8) * BLOCK_M;
}
else
{
tid = (lane_idx - 16) * BLOCK_M + 8;
}
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i)
{
SM90_U32x4_STSM_T<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + warp_idx * 16 + i * 16 * BLOCK_M + tid);
}
if constexpr (WGMMA::kNumAccum % 8 != 0)
{
SM90_U32x2_STSM_T<nv_bfloat162>::copy(__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0],
final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn(
{final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + warp_idx * 16 + WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid);
}
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
{
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
}
}
else if constexpr (SchedulerType::gemm_type == GemmType::StridedBatched)
{
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
__nv_bfloat16* gmem_d_this_block;
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
gmem_d_this_block = gmem_d + scheduler.curr_group_idx * problem_input.stride_d
+ (n_block_idx * BLOCK_N) * problem_input.ld_d;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, problem_input.ld_d);
}
else
{
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, smem_d, m_block_idx * BLOCK_M, scheduler.get_global_n_idx(n_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
} // namespace deep_gemm

View File

@ -62,10 +62,10 @@ using GemmConfig
std::string gemm_type_to_string(deep_gemm::GemmType gemm_type);
int div_up(int a, int b);
int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k);
int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k, bool swap_ab);
bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms);
GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, int num_groups,
int num_device_sms, bool is_grouped_contiguous);
int num_device_sms, bool is_grouped_contiguous, bool swap_ab);
} // namespace deep_gemm::jit
namespace deep_gemm::jit
@ -90,23 +90,46 @@ int div_up(int a, int b)
return (a + b - 1) / b;
}
int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128)
int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128, bool swap_ab = false)
{
int smem_d = block_m * block_n * 2;
int smem_a_per_stage = block_m * block_k;
int smem_scales_a_per_stage = block_m * 4;
int smem_b_per_stage = block_n * block_k;
int smem_scales_b = div_up(k, block_k) * 4;
int smem_barrier = num_stages * 8 * 2;
if (!swap_ab)
{
int smem_d = block_m * block_n * 2;
int smem_a_per_stage = block_m * block_k;
int smem_scales_a_per_stage = block_m * 4;
int smem_b_per_stage = block_n * block_k;
int smem_scales_b = div_up(k, block_k) * 4;
int smem_barrier = num_stages * 8 * 2;
int smem_size = 0;
smem_size += smem_d;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_scales_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
smem_size += div_up(smem_scales_b * (block_k % block_n == 0 ? 1 : 2), 8) * 8;
smem_size += smem_barrier;
return smem_size;
int smem_size = 0;
smem_size += smem_d;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_scales_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
smem_size += div_up(smem_scales_b * (block_k % block_n == 0 ? 1 : 2), 8) * 8;
smem_size += smem_barrier;
return smem_size;
}
else
{
int smem_d = block_n * block_m * 2;
int smem_a_per_stage = block_m * block_k; // weight
int smem_scales_a_per_stage = div_up(k, block_k) * 4; // weight scales
int smem_b_per_stage = block_n * block_k; // act
int smem_scales_b = div_up(block_n * 4, 128) * 128; // act scales, TMA 128B alignment
int smem_barrier = num_stages * 8 * 2;
int smem_size = 0;
smem_size += smem_d;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_scales_b;
smem_size += num_stages * smem_b_per_stage;
smem_size += div_up(smem_scales_a_per_stage, 8) * 8;
smem_size += smem_barrier;
return smem_size;
}
}
bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms)
@ -119,7 +142,7 @@ bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_s
}
GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, int num_groups,
int num_device_sms, bool is_grouped_contiguous = false)
int num_device_sms, bool is_grouped_contiguous = false, bool swap_ab = false)
{
// Choose candidate block sizes
std::vector<int> block_ms;
@ -196,7 +219,7 @@ GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t sha
for (int num_stages : stage_candidates)
{
int smem_size = get_smem_size(num_stages, shape_k, best_block_m, best_block_n);
int smem_size = get_smem_size(num_stages, shape_k, best_block_m, best_block_n, 128, swap_ab);
if (smem_size <= sm90_capacity)
{
best_num_stages = num_stages;
@ -208,9 +231,19 @@ GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t sha
// Determine TMA multicast settings
int best_num_tma_multicast = 1;
if (shape_m >= 1024 && is_tma_multicast_legal(shape_n, best_block_n, 2, num_device_sms) && num_groups == 1)
if (!swap_ab)
{
best_num_tma_multicast = 2;
if (shape_m >= 1024 && is_tma_multicast_legal(shape_n, best_block_n, 2, num_device_sms) && num_groups == 1)
{
best_num_tma_multicast = 2;
}
}
else
{
if (shape_n >= 1024 && is_tma_multicast_legal(shape_m, best_block_m, 2, num_device_sms) && num_groups == 1)
{
best_num_tma_multicast = 2;
}
}
return std::make_tuple(best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size);

View File

@ -760,6 +760,30 @@ struct SM90_U32x4_STSM_N
}
};
template <typename dtype_t>
struct SM90_U32x2_STSM_T
{
__device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst)
{
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]),
"r"(src[1]));
}
};
template <typename dtype_t>
struct SM90_U32x4_STSM_T
{
__device__ __forceinline__ static void copy(
dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst)
{
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" ::"l"(smem_dst),
"r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
}
};
__device__ void warpgroup_arrive()
{
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
@ -805,6 +829,13 @@ __device__ __forceinline__ float ld_shared(float const* __restrict__ ptr)
return ret;
}
__device__ __forceinline__ float2 ld_shared(float2 const* __restrict__ ptr)
{
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_shared(float const* ptr, float val)
{
asm volatile("st.shared.f32 [%0], %1;" ::"l"(ptr), "f"(val));

View File

@ -68,6 +68,12 @@ struct NormalSchedulerInput
int* grouped_layout; // no use
};
struct NormalSchedulerInputSwapAB
{
uint32_t shape_n;
int* grouped_layout; // no use
};
struct GroupedContiguousSchedulerInput
{
uint32_t shape_m;
@ -87,6 +93,13 @@ struct GroupedWithOffsetSchedulerInput
int64_t* problem_m_padded_offsets;
};
struct GroupedWithOffsetSchedulerInputSwapAB
{
uint32_t shape_m;
int64_t* problem_n_offsets;
int64_t* problem_n_padded_4_offsets;
};
struct StridedBatchedSchedulerInput
{
uint32_t shape_m;
@ -98,6 +111,17 @@ struct StridedBatchedSchedulerInput
uint64_t stride_d;
};
struct StridedBatchedSchedulerInputSwapAB
{
uint32_t shape_n;
uint64_t ld_a;
uint64_t stride_a;
uint64_t ld_b;
uint64_t stride_b;
uint64_t ld_d;
uint64_t stride_d;
};
template <uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct NormalScheduler
@ -155,6 +179,68 @@ struct NormalScheduler
}
};
template <uint32_t SHAPE_M, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct NormalSchedulerSwapAB
{
static constexpr GemmType gemm_type = GemmType::Normal;
int current_iter = -1;
uint32_t num_aligned_n_blocks;
uint32_t num_blocks;
using Input = NormalSchedulerInputSwapAB;
Input input;
NormalSchedulerSwapAB() {}
__device__ __forceinline__ NormalSchedulerSwapAB(Input& input)
{
num_aligned_n_blocks = ceil_div(input.shape_n, BLOCK_N);
num_blocks = num_aligned_n_blocks * kNumMBlocks;
}
// weight
__device__ __forceinline__ uint32_t get_global_m_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return block_idx * block_size;
}
// act
__device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
{
return block_idx * BLOCK_N;
}
// act scales
__device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
{
return block_idx;
}
// weight scales
__device__ __forceinline__ uint32_t get_global_scales_a_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return block_idx * block_size;
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
{
++current_iter;
auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
if (next_block_idx >= num_blocks)
{
return false;
}
get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
num_aligned_n_blocks, next_block_idx, n_block_idx, m_block_idx);
return true;
}
};
template <uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumNBlocks, uint32_t kNumNBlocksPerGroup>
struct GroupedContiguousScheduler
@ -376,6 +462,91 @@ struct GroupedWithOffsetScheduler
}
};
template <uint32_t SHAPE_M, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct GroupedWithOffsetSchedulerSwapAB
{
static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;
int current_iter = -1;
uint32_t curr_group_idx;
uint32_t curr_cumsum;
int64_t n_offset;
int64_t n_padded_4_offset;
int64_t n_boundary;
int64_t* problem_n_offsets;
int64_t* problem_n_padded_4_offsets;
using Input = GroupedWithOffsetSchedulerInputSwapAB;
Input input;
GroupedWithOffsetSchedulerSwapAB() {}
__device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input)
{
this->problem_n_offsets = input.problem_n_offsets;
this->problem_n_padded_4_offsets = input.problem_n_padded_4_offsets;
curr_group_idx = 0;
curr_cumsum = 0;
}
// weight
__device__ __forceinline__ uint32_t get_global_m_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}
// act
__device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
{
return n_offset + block_idx * BLOCK_N;
}
// act scales
__device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
{
return n_padded_4_offset + block_idx * BLOCK_N;
}
// weight scales
__device__ __forceinline__ uint32_t get_global_scales_a_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
{
++current_iter;
auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
uint32_t num_n_blocks;
while (true)
{
// End of the task
if (curr_group_idx == kNumGroups)
return false;
n_padded_4_offset = __ldg(problem_n_padded_4_offsets + curr_group_idx);
n_offset = __ldg(problem_n_offsets + curr_group_idx);
n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1);
auto n = n_boundary - n_offset;
// Within current group
num_n_blocks = ceil_div(n, static_cast<int64_t>(BLOCK_N));
auto current_n_block_cumsum = curr_cumsum + num_n_blocks;
if (next_block_idx < current_n_block_cumsum * kNumMBlocks)
break;
// Move to check the next group
curr_group_idx++;
curr_cumsum = current_n_block_cumsum;
}
get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
return true;
}
};
template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct StridedBatchedScheduler
@ -453,6 +624,88 @@ struct StridedBatchedScheduler
}
};
template <uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct StridedBatchedSchedulerSwapAB
{
static constexpr GemmType gemm_type = GemmType::StridedBatched;
int current_iter = -1;
uint32_t curr_group_idx;
uint32_t curr_cumsum;
int64_t n_offset;
int64_t n_boundary;
using Input = StridedBatchedSchedulerInputSwapAB;
Input input;
StridedBatchedSchedulerSwapAB() {}
__device__ __forceinline__ StridedBatchedSchedulerSwapAB(Input& input)
{
this->input = input;
curr_group_idx = 0;
curr_cumsum = 0;
}
// weight
__device__ __forceinline__ uint32_t get_global_m_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
// Assuming stride_a % ld_a == 0 && stride_a >= ld_a
return input.stride_a / input.ld_a * curr_group_idx + block_idx * block_size;
}
// act
__device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
{
// Assuming stride_b % ld_b == 0 && stride_b >= ld_b
return input.stride_b / input.ld_b * curr_group_idx + block_idx * BLOCK_N;
}
// act scales
__device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
{
return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
}
// weight scales
__device__ __forceinline__ uint32_t get_global_scales_a_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
{
++current_iter;
auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
uint32_t num_n_blocks;
while (true)
{
// End of the task
if (curr_group_idx == kNumGroups)
return false;
n_offset = curr_group_idx * input.shape_n;
n_boundary = (curr_group_idx + 1) * input.shape_n;
// Within current group
num_n_blocks = ceil_div(input.shape_n, BLOCK_N);
auto current_n_block_cumsum = curr_cumsum + num_n_blocks;
if (next_block_idx < current_n_block_cumsum * kNumMBlocks)
break;
// Move to check the next group
curr_group_idx++;
curr_cumsum = current_n_block_cumsum;
}
// Note: Here, m and n roles are swapped
get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
return true;
}
};
template <GemmType GT, uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
@ -480,6 +733,26 @@ struct SchedulerSelector
using type = decltype(select_type());
};
template <GemmType GT, uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
uint32_t kNumMBlocksPerGroup = 16>
struct SchedulerSelectorSwapAB
{
static constexpr auto select_type()
{
static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal,
"Only GroupedWithOffset and Normal are supported for SwapAB");
if constexpr (GT == GemmType::Normal)
return NormalSchedulerSwapAB<SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumMBlocks,
kNumMBlocksPerGroup>();
if constexpr (GT == GemmType::GroupedWithOffset)
return GroupedWithOffsetSchedulerSwapAB<SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast,
kNumMBlocks, kNumMBlocksPerGroup>();
}
using type = decltype(select_type());
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@ -32,6 +32,7 @@
#include <cuda/barrier>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <stdexcept>
#endif
#include "utils.cuh"

View File

@ -1626,16 +1626,33 @@ void gemm_dispatch(void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, in
constexpr uint32_t block_k = 128;
constexpr uint32_t num_problems = 1;
// Select the best configuration based on shape dimensions
auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size]
= deep_gemm::jit::get_best_gemm_config(shape_m, shape_n, shape_k, num_problems, num_device_sms);
uint32_t m_threshold = 32;
if (shape_m >= m_threshold)
{
// Select the best configuration based on shape dimensions
auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size]
= deep_gemm::jit::get_best_gemm_config(shape_m, shape_n, shape_k, num_problems, num_device_sms);
auto runtime = deep_gemm::jit::getGlobalCompiler().build(shape_n, shape_k, best_block_m, best_block_n, block_k,
num_problems, best_num_stages, best_num_tma_multicast, deep_gemm::GemmType::Normal);
auto kernel = reinterpret_cast<cudaKernel_t>(runtime->getKernel());
deep_gemm::runGemm(kernel, mat_a, ld_a, mat_b, ld_b, mat_d, ld_d, scales_a, scales_b, shape_m, shape_n, shape_k,
best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast, deep_gemm::GemmType::Normal,
static_cast<int*>(nullptr), stream, num_device_sms, static_cast<uint32_t>(best_smem_size));
auto runtime = deep_gemm::jit::getGlobalCompiler().build(shape_n, shape_k, best_block_m, best_block_n, block_k,
num_problems, best_num_stages, best_num_tma_multicast, deep_gemm::GemmType::Normal);
auto kernel = reinterpret_cast<cudaKernel_t>(runtime->getKernel());
deep_gemm::runGemm(kernel, mat_a, ld_a, mat_b, ld_b, mat_d, ld_d, scales_a, scales_b, shape_m, shape_n, shape_k,
best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast, deep_gemm::GemmType::Normal,
static_cast<int*>(nullptr), stream, num_device_sms, static_cast<uint32_t>(best_smem_size));
}
else
{
auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size]
= deep_gemm::jit::get_best_gemm_config(
shape_n, shape_m, shape_k, num_problems, num_device_sms, false, true);
auto runtime = deep_gemm::jit::getGlobalCompiler().build(shape_n, shape_k, best_block_m, best_block_n, block_k,
num_problems, best_num_stages, best_num_tma_multicast, deep_gemm::GemmType::Normal, true);
auto kernel = reinterpret_cast<cudaKernel_t>(runtime->getKernel());
deep_gemm::runGemmSwapAB(kernel, mat_b, ld_b, mat_a, ld_a, mat_d, ld_d, scales_b, scales_a, shape_n, shape_m,
shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast,
deep_gemm::GemmType::Normal, static_cast<int*>(nullptr), stream, num_device_sms,
static_cast<uint32_t>(best_smem_size));
}
}
void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
@ -1693,16 +1710,34 @@ void grouped_gemm_dispatch(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bflo
}
constexpr uint32_t block_k = 128;
auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size]
= deep_gemm::jit::get_best_gemm_config(expected_m, shape_n, shape_k, num_problems, num_device_sms);
uint32_t m_per_expert_threshold = num_device_sms == 78 ? 64 : 32; // 64 for H20(sms=78), 32 for H100/H200
if (expected_m >= m_per_expert_threshold)
{
auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size]
= deep_gemm::jit::get_best_gemm_config(expected_m, shape_n, shape_k, num_problems, num_device_sms);
auto runtime = deep_gemm::jit::getGlobalCompiler().build(shape_n, shape_k, best_block_m, best_block_n, block_k,
num_problems, best_num_stages, best_num_tma_multicast, deep_gemm::GemmType::GroupedWithOffset);
auto kernel = reinterpret_cast<cudaKernel_t>(runtime->getKernel());
deep_gemm::runGemm(kernel, mat_a, 0, mat_b, 0, mat_d, 0, scales_a, scales_b, max_shape_m, shape_n, shape_k,
best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast,
deep_gemm::GemmType::GroupedWithOffset, const_cast<int64_t*>(problem_m_offsets), problem_m_padded_offsets,
stream, num_device_sms, static_cast<uint32_t>(best_smem_size), max_shape_m_padded);
auto runtime = deep_gemm::jit::getGlobalCompiler().build(shape_n, shape_k, best_block_m, best_block_n, block_k,
num_problems, best_num_stages, best_num_tma_multicast, deep_gemm::GemmType::GroupedWithOffset);
auto kernel = reinterpret_cast<cudaKernel_t>(runtime->getKernel());
deep_gemm::runGemm(kernel, mat_a, 0, mat_b, 0, mat_d, 0, scales_a, scales_b, max_shape_m, shape_n, shape_k,
best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast,
deep_gemm::GemmType::GroupedWithOffset, const_cast<int64_t*>(problem_m_offsets), problem_m_padded_offsets,
stream, num_device_sms, static_cast<uint32_t>(best_smem_size), max_shape_m_padded);
}
else
{
auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size]
= deep_gemm::jit::get_best_gemm_config(
shape_n, expected_m, shape_k, num_problems, num_device_sms, false, true);
auto runtime = deep_gemm::jit::getGlobalCompiler().build(shape_n, shape_k, best_block_m, best_block_n, block_k,
num_problems, best_num_stages, best_num_tma_multicast, deep_gemm::GemmType::GroupedWithOffset, true);
auto kernel = reinterpret_cast<cudaKernel_t>(runtime->getKernel());
deep_gemm::runGemmSwapAB(kernel, mat_b, 0, mat_a, 0, mat_d, 0, scales_b, scales_a, shape_n, max_shape_m,
shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast,
deep_gemm::GemmType::GroupedWithOffset, const_cast<int64_t*>(problem_m_offsets), problem_m_padded_offsets,
stream, num_device_sms, static_cast<uint32_t>(best_smem_size), max_shape_m_padded);
}
}
void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a,

View File

@ -1,7 +1,7 @@
# DeepSeekV3 and DeepSeek-R1
This guide walks you through the examples to run the DeepSeekV3/DeepSeek-R1 models using NVIDIA's TensorRT-LLM framework with the PyTorch backend.
**DeepSeek-R1 and DeepSeek-V3 share exact same model architecture other than weights differences, and share same code path in TensorRT-LLM, for brevity we only provide one model example, the example command to be used interchangeablely by only replacing the model name to the other one**.
**DeepSeek-R1 and DeepSeek-V3 share the exact same model architecture, differing only in their weights, and use the same code path in TensorRT-LLM. For brevity, we provide only one model example; the commands can be used interchangeably by simply replacing the model name**.
To benchmark the model with best configurations, refer to [DeepSeek R1 benchmarking blog](../../../../docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md).
@ -509,6 +509,25 @@ DeepGEMM-related behavior can be controlled by the following environment variabl
| `TRTLLM_DG_JIT_USE_NVCC` | When set to `1`, use NVCC instead of NVRTC to compile the kernel, which has slightly better performance but requires CUDA Toolkit (>=12.3) and longer compilation time.|
| `TRTLLM_DG_JIT_DUMP_CUBIN` | When set to `1`, dump the cubin file. This is only effective with NVRTC since NVCC will always dump the cubin file. NVRTC-based JIT will store the generated kernels in memory by default. If you want to persist the kernels across multiple runs, you can either use this variable or use NVCC. |
#### MOE GEMM Optimization
For Mixture of Experts (MOE) GEMM operations, TensorRT-LLM's DeepGEMM includes the optimized `fp8_gemm_kernel_swapAB` kernel. This kernel is automatically selected based on the input dimensions and GPU type:
- On H20 GPUs (SM count = 78): Uses `fp8_gemm_kernel_swapAB` when the expected m_per_expert is less than 64
- On H100/H200 GPUs: Uses `fp8_gemm_kernel_swapAB` when the expected m_per_expert is less than 32
- Otherwise, uses the original `fp8_gemm_kernel`
This automatic selection provides better performance for different workload sizes across various Hopper GPUs. In our test cases, the `fp8_gemm_kernel_swapAB` kernel achieves up to 1.8x speedup for individual kernels on H20 GPUs and up to 1.3x speedup on H100 GPUs.
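
The heuristic itself is a single threshold comparison taken from `grouped_gemm_dispatch` in this commit. Below is a minimal sketch of it; the helper name `select_moe_gemm_kernel` and the `MoeGemmKernel` enum are illustrative only, not part of the actual API:

```cpp
#include <cstdint>

// Illustrative enum for the two kernel choices (not a real type in the codebase).
enum class MoeGemmKernel { Fp8Gemm, Fp8GemmSwapAB };

// Sketch of the threshold logic in grouped_gemm_dispatch:
// 78 SMs identifies H20 (threshold 64); H100/H200 use a threshold of 32.
inline MoeGemmKernel select_moe_gemm_kernel(uint32_t expected_m_per_expert, int num_device_sms)
{
    uint32_t const m_per_expert_threshold = (num_device_sms == 78) ? 64u : 32u;
    return (expected_m_per_expert >= m_per_expert_threshold) ? MoeGemmKernel::Fp8Gemm
                                                             : MoeGemmKernel::Fp8GemmSwapAB;
}
```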
#### Dense GEMM Optimization
The same optimization has been extended to Dense GEMM operations. For regular dense matrix multiplications:
- On all Hopper GPUs (H20, H100, H200): Uses `fp8_gemm_kernel_swapAB` when m is less than 32
- Otherwise, uses the original `fp8_gemm_kernel`
This optimization delivers significant performance improvements for small batch sizes. Our benchmarks show that the `fp8_gemm_kernel_swapAB` kernel achieves up to 1.7x speedup on H20 GPUs and up to 1.8x speedup on H100 GPUs for certain matrix dimensions.
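
On this path, `gemm_dispatch` (see the diff above) also exchanges the operands: the weight matrix is passed to `runGemmSwapAB` as operand A, the activations as operand B, and the scale pointers and the (m, n) shape are swapped to match. A rough sketch of that argument swap, using a hypothetical `DenseGemmArgs` struct in place of the real parameter list:

```cpp
#include <cstdint>

// Hypothetical, simplified stand-in for the real gemm_dispatch arguments
// (leading dimensions, stream, and SM count omitted).
struct DenseGemmArgs
{
    void* mat_a;      // operand A: activations on the standard path
    void* mat_b;      // operand B: weights on the standard path
    float* scales_a;  // scales for operand A
    float* scales_b;  // scales for operand B
    uint32_t shape_m, shape_n, shape_k;
};

// For shape_m below the threshold of 32, the swapAB kernel receives the weights
// as "A" and the activations as "B", with the scales and (m, n) exchanged too.
inline DenseGemmArgs to_swap_ab_args(DenseGemmArgs const& args)
{
    return DenseGemmArgs{args.mat_b, args.mat_a, args.scales_b, args.scales_a,
                         args.shape_n, args.shape_m, args.shape_k};
}
```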
```bash
#single-node
trtllm-bench \

View File

@ -31,7 +31,7 @@ from utils.util import getSMVersion
)
@pytest.mark.parametrize(
"m",
[7, 64, 128, 4096],
[7, 16, 64, 128, 4096],
)
def test_fp8_block_scale_gemm(m, k, n):
torch.random.manual_seed(0)
@ -162,7 +162,8 @@ def construct_batched(
getSMVersion() != 90,
reason="Op only supported on Hopper, current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize("ms", [[256, 256], [128, 64, 64], [16, 24, 48]])
@pytest.mark.parametrize(
"ms", [[256, 256], [128, 64, 64], [16, 24, 48], [4, 8, 16, 32]])
@pytest.mark.parametrize("k, n", [(7168, 4096), (2048, 7168)])
def test_fp8_block_scaling_moe_gemm(ms, k, n):
offset_cpu = [0] + list(itertools.accumulate(ms))