add sm103 fp4 cutlass gemm

Signed-off-by: Xiwen Yu <xiweny@nvidia.com>
This commit is contained in:
Xiwen Yu 2025-09-05 00:06:41 -07:00
parent 9ae01a8edb
commit 973fd37457
6 changed files with 283 additions and 183 deletions

View File

@@ -23,38 +23,55 @@ namespace kernels
namespace cutlass_kernels
{
#ifdef ENABLE_BF16
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 128, 128, 128, 1, 1, 1)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 128, 128, 256, 1, 1, 1)

View File

@@ -22,38 +22,55 @@ namespace kernels
{
namespace cutlass_kernels
{
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 128, 128, 128, 1, 1, 1)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 128, 128, 256, 1, 1, 1)

View File

@@ -22,38 +22,55 @@ namespace kernels
{
namespace cutlass_kernels
{
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 4, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 128, 128, 128, 1, 1, 1)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 128, 128, 256, 1, 1, 1)

View File

@@ -16,6 +16,7 @@
#pragma once
#include <type_traits>
#ifndef _WIN32
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -52,8 +53,8 @@ using namespace cute;
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_>
size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const* B, void const* input_sf,
template <typename Arch, typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_>
size_t dispatchNVFP4xNVFP4GemmClusterShapeSm10x(T* D, void const* A, void const* B, void const* input_sf,
void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count,
tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream,
int* occupancy = nullptr)
@@ -64,43 +65,43 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const*
switch (gemmConfig.cluster_shape)
{
case tkc::ClusterShape::ClusterShape_1x1x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<1>, cute::Int<1>, _1SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<1>, cute::Int<1>,
_1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_2x1x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<1>, cute::Int<1>, _2SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<1>, cute::Int<1>,
_2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_1x2x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<2>, cute::Int<1>, _1SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<2>, cute::Int<1>,
_1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_2x2x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<2>, cute::Int<1>, _2SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<2>, cute::Int<1>,
_2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_1x4x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<4>, cute::Int<1>, _1SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<4>, cute::Int<1>,
_1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_4x2x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<2>, cute::Int<1>, _2SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<2>, cute::Int<1>,
_2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_2x4x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<4>, cute::Int<1>, _2SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<4>, cute::Int<1>,
_2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_4x4x1:
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<4>, cute::Int<1>, _2SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<4>, cute::Int<1>,
_2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
stream, occupancy);
break;
default:
@@ -110,8 +111,8 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const*
}
}
template <typename T>
size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B, void const* input_sf,
template <typename Arch, typename T>
size_t dispatchNVFP4xNVFP4GemmCTAShapeSm10x(T* D, void const* A, void const* B, void const* input_sf,
void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count,
tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream,
int* occupancy = nullptr)
@@ -123,39 +124,48 @@ size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B,
// M-mode size should be 128 or 256 for 2 CTA cluster MMA;
// M-mode size should be 128 for 1 CTA cluster OMMA.
// K256 looks to be better than K128
switch (gemmConfig.tile_config_sm100)
// Expands to one `switch` case: maps the CtaShape{M}x{N}x{K}B value of
// tkc::CutlassTileConfigSM100 to the cluster-shape dispatcher instantiated
// with the same CTA extents. Relies on `Arch`, `T` and the launcher argument
// names (D, A, B, input_sf, ...) being in scope at the expansion site, i.e.
// inside dispatchNVFP4xNVFP4GemmCTAShapeSm10x.
#define CTA_CASE(M, N, K) \
case tkc::CutlassTileConfigSM100::CtaShape##M##x##N##x##K##B: \
return dispatchNVFP4xNVFP4GemmClusterShapeSm10x<Arch, T, cute::Int<M>, cute::Int<N>, cute::Int<K>>(D, A, B, \
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, \
occupancy);
// Shared trailing cases for the tile-config switch: rejects Undefined and
// ChooseWithHeuristic explicitly, and throws for any tile config not handled
// by an earlier CTA_CASE expansion.
#define CTA_CASE_DEFAULT \
case tkc::CutlassTileConfigSM100::Undefined: \
throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined."); \
break; \
case tkc::CutlassTileConfigSM100::ChooseWithHeuristic: \
throw std::runtime_error( \
"[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by " \
"heuristic."); \
break; \
default: \
throw std::runtime_error( \
"[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM."); \
break;
if constexpr (std::is_same_v<Arch, cutlass::arch::Sm100>)
{
switch (gemmConfig.tile_config_sm100)
{
CTA_CASE(128, 64, 128)
CTA_CASE(128, 256, 128)
CTA_CASE(128, 128, 256)
CTA_CASE(128, 256, 256)
CTA_CASE_DEFAULT
}
}
else if constexpr (std::is_same_v<Arch, cutlass::arch::Sm103>)
{
switch (gemmConfig.tile_config_sm100)
{
CTA_CASE(128, 128, 256)
CTA_CASE(128, 256, 256)
CTA_CASE_DEFAULT
}
}
else
{
case tkc::CutlassTileConfigSM100::CtaShape128x64x128B:
return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<64>, cute::Int<128>>(D, A, B,
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
occupancy);
break;
case tkc::CutlassTileConfigSM100::CtaShape128x256x128B:
return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<128>>(D, A, B,
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
occupancy);
break;
case tkc::CutlassTileConfigSM100::CtaShape128x128x256B:
return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<128>, cute::Int<256>>(D, A, B,
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
occupancy);
break;
case tkc::CutlassTileConfigSM100::CtaShape128x256x256B:
return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<256>>(D, A, B,
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
occupancy);
break;
case tkc::CutlassTileConfigSM100::Undefined:
throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined.");
break;
case tkc::CutlassTileConfigSM100::ChooseWithHeuristic:
throw std::runtime_error(
"[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by "
"heuristic.");
break;
default:
throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM.");
break;
"[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Architecture not supported for FP4 GEMM.");
}
}
@@ -343,10 +353,20 @@ size_t CutlassFp4GemmRunner<T, fp4GemmType>::dispatchToArch(T* D, void const* A,
}
else if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4)
{
if (mSm == 100 || mSm == 103)
if (mSm == 103)
{
return dispatchNVFP4xNVFP4GemmCTAShapeSm100<T>(D, A, B, input_sf, weight_sf, global_sf, m, n, k,
batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
#ifdef COMPILE_BLACKWELL_SM103_TMA_GEMMS
return dispatchNVFP4xNVFP4GemmCTAShapeSm10x<cutlass::arch::Sm103, T>(D, A, B, input_sf, weight_sf,
global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
#else
return dispatchNVFP4xNVFP4GemmCTAShapeSm10x<cutlass::arch::Sm100, T>(D, A, B, input_sf, weight_sf,
global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
#endif
}
else if (mSm == 100)
{
return dispatchNVFP4xNVFP4GemmCTAShapeSm10x<cutlass::arch::Sm100, T>(D, A, B, input_sf, weight_sf,
global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
}
else if (mSm == 120 || mSm == 121)
{
@@ -386,12 +406,15 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
if (mSm == 100 || mSm == 103)
{
std::vector<tkc::CutlassTileConfigSM100> tilesSm100 = {
tkc::CutlassTileConfigSM100::CtaShape128x64x128B,
tkc::CutlassTileConfigSM100::CtaShape128x256x128B,
std::vector<tkc::CutlassTileConfigSM100> tilesSm10x = {
tkc::CutlassTileConfigSM100::CtaShape128x128x256B,
tkc::CutlassTileConfigSM100::CtaShape128x256x256B,
};
if (mSm == 100)
{
tilesSm10x.push_back(tkc::CutlassTileConfigSM100::CtaShape128x64x128B);
tilesSm10x.push_back(tkc::CutlassTileConfigSM100::CtaShape128x256x128B);
}
std::vector<tkc::ClusterShape> clusterShapes = {
tkc::ClusterShape::ClusterShape_1x1x1,
tkc::ClusterShape::ClusterShape_1x2x1,
@@ -402,7 +425,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
tkc::ClusterShape::ClusterShape_2x4x1,
tkc::ClusterShape::ClusterShape_4x4x1,
};
for (auto const& tile_config : tilesSm100)
for (auto const& tile_config : tilesSm10x)
{
for (auto const& cluster_config : clusterShapes)
{
@@ -417,7 +440,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
}
}
CutlassGemmConfig config(
tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, cluster_config);
tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, cluster_config, mSm);
candidateConfigs.push_back(config);
}
}

View File

@@ -60,13 +60,13 @@ struct _2SM
{
};
template <typename T>
template <typename Arch, typename T>
struct SMTypeAdapter
{
};
template <>
struct SMTypeAdapter<_1SM>
struct SMTypeAdapter<cutlass::arch::Sm100, _1SM>
{
static int const Scale = 1;
using AtomThrShape = cute::Shape<_1, _1, _1>;
@@ -75,7 +75,7 @@ struct SMTypeAdapter<_1SM>
};
template <>
struct SMTypeAdapter<_2SM>
struct SMTypeAdapter<cutlass::arch::Sm100, _2SM>
{
static int const Scale = 2;
using AtomThrShape = cute::Shape<_2, _1, _1>;
@ -83,11 +83,29 @@ struct SMTypeAdapter<_2SM>
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
};
// SM103 (Blackwell Ultra) 1-SM specialization: one CTA participates in each MMA atom.
template <>
struct SMTypeAdapter<cutlass::arch::Sm103, _1SM>
{
// Scale multiplies CTA_M to form the MMA tile M dimension and sizes the
// fallback cluster shape; 1 => MMA tile equals the CTA tile.
static int const Scale = 1;
using AtomThrShape = cute::Shape<_1, _1, _1>;
// Epilogue that skips shared-memory staging, single-SM variant.
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized1Sm;
// NOTE(review): schedule name suggests block-scaled NVF4 with Vs16 scale-vector
// granularity specific to SM103 "Ultra" -- confirm against cutlass kernel tags.
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103;
};
// SM103 (Blackwell Ultra) 2-SM specialization: two CTAs cooperate on each MMA atom.
template <>
struct SMTypeAdapter<cutlass::arch::Sm103, _2SM>
{
// Scale multiplies CTA_M to form the MMA tile M dimension and sizes the
// fallback cluster shape; 2 => MMA tile M is twice the CTA tile M.
static int const Scale = 2;
using AtomThrShape = cute::Shape<_2, _1, _1>;
// Epilogue that skips shared-memory staging, two-SM variant.
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized2Sm;
// NOTE(review): schedule name suggests block-scaled NVF4 with Vs16 scale-vector
// granularity specific to SM103 "Ultra" -- confirm against cutlass kernel tags.
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103;
};
// Dependent-false helper: always evaluates to false, but depends on a template
// parameter so it is presumably used to make static_assert fire only when an
// unsupported specialization is actually instantiated -- usage not visible here.
template <typename>
constexpr auto always_false = false;
template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_, typename CGA_M_, typename CGA_N_,
typename CGA_K_, typename XSM_>
template <typename Arch, typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_, typename CGA_M_,
typename CGA_N_, typename CGA_K_, typename XSM_>
size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf,
float const* global_sf, int m, int n, int k, int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace,
size_t const workspaceBytes, cudaStream_t stream, int* occupancy)
@ -98,13 +116,13 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
#ifdef PLACEHOLDER_KERNELS
#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(ARCH_, T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
template <> \
size_t genericFp4GemmKernelLauncher<T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>, cute::Int<CGA_M_>, \
cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, void const* B, void const* input_sf, \
void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, \
tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, \
int* occupancy) \
size_t genericFp4GemmKernelLauncher<cutlass::arch::ARCH_, T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, \
cute::Int<CTA_K_>, cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, \
void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, \
int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, \
cudaStream_t stream, int* occupancy) \
{ \
throw std::runtime_error( \
"[TensorRT-LLM Error][FP4 gemm Runner] TensorRT-LLM is not compiled with support for this Architecture."); \
@ -112,15 +130,15 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
#else
#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
struct DeviceGemmFp4GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ \
#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(ARCH_, T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
struct DeviceGemmFp4Gemm##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ \
{ \
using OutElementType = TllmToCutlassTypeAdapter<T>::type; \
using CTAShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
/*using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>;*/ \
using ClusterShape = cute::Shape<int, int, _1>; \
using ElementType = cutlass::float_e2m1_t; \
using Arch = cutlass::arch::Sm100; \
using Arch = cutlass::arch::ARCH_; \
/* // Input A */ \
using ElementA = ElementType; \
using LayoutA = cutlass::layout::RowMajor; \
@ -140,10 +158,10 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
using OperatorClass = cutlass::arch::OpClassTensorOp; \
using EpilogueTileType = std::conditional_t<CTA_M_ == 128 && CTA_N_ == 256 && CTA_K_ == 256, \
cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
using EpilogueSchedule = SMTypeAdapter<XSM_>::EpilogueSchedule; \
using MainloopSchedule = SMTypeAdapter<XSM_>::MainloopSchedule; \
using MmaTileShape \
= cute::Shape<cute::Int<CTA_M_ * SMTypeAdapter<XSM_>::Scale>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
using EpilogueSchedule = SMTypeAdapter<Arch, XSM_>::EpilogueSchedule; \
using MainloopSchedule = SMTypeAdapter<Arch, XSM_>::MainloopSchedule; \
using MmaTileShape = cute::Shape<cute::Int<CTA_M_ * SMTypeAdapter<Arch, XSM_>::Scale>, cute::Int<CTA_N_>, \
cute::Int<CTA_K_*(std::is_same_v<Arch, cutlass::arch::Sm103> ? 3 : 1)>>; \
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<Arch, OperatorClass, \
MmaTileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, LayoutC, \
AlignmentC, OutElementType, LayoutC, AlignmentC, EpilogueSchedule, \
@ -185,7 +203,7 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
\
template <typename Gemm> \
typename Gemm::Arguments \
prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_(void* D, \
prepareGemmArgs_##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_(void* D, \
void const* A, void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, \
int n, int k, int batch_count) \
{ \
@ -234,17 +252,18 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
operator_args.scheduler.raster_order = Enum_t::Heuristic; \
} \
operator_args.hw_info.cluster_shape = dim3(CGA_M_, CGA_N_, CGA_K_); \
operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter<XSM_>::Scale, 1, 1); \
using Arch = cutlass::arch::ARCH_; \
operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter<Arch, XSM_>::Scale, 1, 1); \
\
return operator_args; \
} \
\
template <> \
size_t genericFp4GemmKernelLauncher<T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>, cute::Int<CGA_M_>, \
cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, void const* B, void const* input_sf, \
void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, \
tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, \
int* occupancy) \
size_t genericFp4GemmKernelLauncher<cutlass::arch::ARCH_, T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, \
cute::Int<CTA_K_>, cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, \
void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, \
int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, \
cudaStream_t stream, int* occupancy) \
{ \
using ElementOutput__ = typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, \
cutlass::half_t, T>::type; \
@ -256,11 +275,12 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
cutlass::bfloat16_t, ElementOutput_>::type; \
\
using Fp4GemmOperator \
= DeviceGemmFp4GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \
= DeviceGemmFp4Gemm##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \
Gemm; \
Fp4GemmOperator gemm; \
auto args = prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \
Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \
auto args \
= prepareGemmArgs_##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \
Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \
/* // Check shared memory size; throw when SMEM exceeds */ \
int smem_size = int(sizeof(typename Fp4GemmOperator::GemmKernel::SharedStorage)); \
static int mMaxSmemSize = tk::getMaxSharedMemoryPerBlockOptin(); \

View File

@ -64,6 +64,12 @@ tkc::CutlassGemmConfig getDefaultGemmConfig(int64_t m, int64_t n, int64_t k, FP4
tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO,
tkc::ClusterShape::ClusterShape_1x1x1);
}
else if (sm == 103)
{
return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x256B,
tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO,
tkc::ClusterShape::ClusterShape_1x1x1);
}
else
{
return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x128B,