mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-25 21:22:57 +08:00)

commit 973fd37457 (parent 9ae01a8edb)

add 3xfp4 cutlass gemm

Signed-off-by: Xiwen Yu <xiweny@nvidia.com>
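In short, this change generalizes the NVFP4 CUTLASS GEMM path over the target architecture: INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER and the genericFp4GemmKernelLauncher / dispatch templates gain a leading architecture parameter, SM103 instantiations are added alongside the existing SM100 ones, and the SM103 specializations select the BlockScaledMxNvf4UltraVs16Sm103 schedules with a tripled K extent in the MMA tile. A minimal sketch of the new instantiation form, with illustrative tile/cluster values (the full lists are in the hunks below):

    // before: the architecture was fixed to Sm100 inside the macro
    INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
    // after: the architecture is the first macro argument
    INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
    INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)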
@@ -23,38 +23,55 @@ namespace kernels
 namespace cutlass_kernels
 {
 #ifdef ENABLE_BF16
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM)
+
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM)

 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 128, 128, 128, 1, 1, 1)
 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 128, 128, 256, 1, 1, 1)
@@ -22,38 +22,55 @@ namespace kernels
 {
 namespace cutlass_kernels
 {
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 4, 4, 1, _2SM)
+
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 4, 4, 1, _2SM)

 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 128, 128, 128, 1, 1, 1)
 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 128, 128, 256, 1, 1, 1)
@@ -22,38 +22,55 @@ namespace kernels
 {
 namespace cutlass_kernels
 {
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 1, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 2, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 4, 1, _1SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 1, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 4, 2, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 4, 4, 1, _2SM)
+
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 4, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 1, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 2, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 4, 1, _1SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 1, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 4, 2, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 4, 4, 1, _2SM)

 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 128, 128, 128, 1, 1, 1)
 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 128, 128, 256, 1, 1, 1)
@@ -16,6 +16,7 @@

 #pragma once

+#include <type_traits>
 #ifndef _WIN32
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -52,8 +53,8 @@ using namespace cute;
 namespace tk = tensorrt_llm::common;
 namespace tkc = tensorrt_llm::cutlass_extensions;

-template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_>
-size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const* B, void const* input_sf,
+template <typename Arch, typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_>
+size_t dispatchNVFP4xNVFP4GemmClusterShapeSm10x(T* D, void const* A, void const* B, void const* input_sf,
     void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count,
     tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream,
     int* occupancy = nullptr)
@@ -64,43 +65,43 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const*
     switch (gemmConfig.cluster_shape)
     {
     case tkc::ClusterShape::ClusterShape_1x1x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<1>, cute::Int<1>, _1SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<1>, cute::Int<1>,
+            _1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_2x1x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<1>, cute::Int<1>, _2SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<1>, cute::Int<1>,
+            _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_1x2x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<2>, cute::Int<1>, _1SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<2>, cute::Int<1>,
+            _1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_2x2x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<2>, cute::Int<1>, _2SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<2>, cute::Int<1>,
+            _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_1x4x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<4>, cute::Int<1>, _1SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<4>, cute::Int<1>,
+            _1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_4x2x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<2>, cute::Int<1>, _2SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<2>, cute::Int<1>,
+            _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_2x4x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<4>, cute::Int<1>, _2SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<4>, cute::Int<1>,
+            _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     case tkc::ClusterShape::ClusterShape_4x4x1:
-        return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<4>, cute::Int<1>, _2SM>(
-            D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
+        return genericFp4GemmKernelLauncher<Arch, T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<4>, cute::Int<1>,
+            _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
             stream, occupancy);
         break;
     default:
@@ -110,8 +111,8 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const*
     }
 }

-template <typename T>
-size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B, void const* input_sf,
+template <typename Arch, typename T>
+size_t dispatchNVFP4xNVFP4GemmCTAShapeSm10x(T* D, void const* A, void const* B, void const* input_sf,
     void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count,
     tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream,
     int* occupancy = nullptr)
@@ -123,39 +124,48 @@ size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B,
     // M-mode size should be 128 or 256 for 2 CTA cluster MMA;
     // M-mode size should be 128 for 1 CTA cluster OMMA.
     // K256 looks to be better than K128
-    switch (gemmConfig.tile_config_sm100)
+#define CTA_CASE(M, N, K) \
+    case tkc::CutlassTileConfigSM100::CtaShape##M##x##N##x##K##B: \
+        return dispatchNVFP4xNVFP4GemmClusterShapeSm10x<Arch, T, cute::Int<M>, cute::Int<N>, cute::Int<K>>(D, A, B, \
+            input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, \
+            occupancy);
+#define CTA_CASE_DEFAULT \
+    case tkc::CutlassTileConfigSM100::Undefined: \
+        throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined."); \
+        break; \
+    case tkc::CutlassTileConfigSM100::ChooseWithHeuristic: \
+        throw std::runtime_error( \
+            "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by " \
+            "heuristic."); \
+        break; \
+    default: \
+        throw std::runtime_error( \
+            "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM."); \
+        break;
+    if constexpr (std::is_same_v<Arch, cutlass::arch::Sm100>)
+    {
+        switch (gemmConfig.tile_config_sm100)
+        {
+            CTA_CASE(128, 64, 128)
+            CTA_CASE(128, 256, 128)
+            CTA_CASE(128, 128, 256)
+            CTA_CASE(128, 256, 256)
+            CTA_CASE_DEFAULT
+        }
+    }
+    else if constexpr (std::is_same_v<Arch, cutlass::arch::Sm103>)
+    {
+        switch (gemmConfig.tile_config_sm100)
+        {
+            CTA_CASE(128, 128, 256)
+            CTA_CASE(128, 256, 256)
+            CTA_CASE_DEFAULT
+        }
+    }
+    else
     {
-    case tkc::CutlassTileConfigSM100::CtaShape128x64x128B:
-        return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<64>, cute::Int<128>>(D, A, B,
-            input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
-            occupancy);
-        break;
-    case tkc::CutlassTileConfigSM100::CtaShape128x256x128B:
-        return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<128>>(D, A, B,
-            input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
-            occupancy);
-        break;
-    case tkc::CutlassTileConfigSM100::CtaShape128x128x256B:
-        return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<128>, cute::Int<256>>(D, A, B,
-            input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
-            occupancy);
-        break;
-    case tkc::CutlassTileConfigSM100::CtaShape128x256x256B:
-        return dispatchNVFP4xNVFP4GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<256>>(D, A, B,
-            input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
-            occupancy);
-        break;
-    case tkc::CutlassTileConfigSM100::Undefined:
-        throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined.");
-        break;
-    case tkc::CutlassTileConfigSM100::ChooseWithHeuristic:
         throw std::runtime_error(
-            "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by "
-            "heuristic.");
-        break;
-    default:
-        throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM.");
-        break;
+            "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Architecture not supported for FP4 GEMM.");
     }
 }
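For reference, each CTA_CASE line inside the switches above expands (whitespace aside) to one case of the previous hand-written form; a sketch for CTA_CASE(128, 256, 256):

    case tkc::CutlassTileConfigSM100::CtaShape128x256x256B:
        return dispatchNVFP4xNVFP4GemmClusterShapeSm10x<Arch, T, cute::Int<128>, cute::Int<256>, cute::Int<256>>(D, A, B,
            input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
            occupancy);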
@@ -343,10 +353,20 @@ size_t CutlassFp4GemmRunner<T, fp4GemmType>::dispatchToArch(T* D, void const* A,
     }
     else if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4)
     {
-        if (mSm == 100 || mSm == 103)
+        if (mSm == 103)
         {
-            return dispatchNVFP4xNVFP4GemmCTAShapeSm100<T>(D, A, B, input_sf, weight_sf, global_sf, m, n, k,
-                batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
+#ifdef COMPILE_BLACKWELL_SM103_TMA_GEMMS
+            return dispatchNVFP4xNVFP4GemmCTAShapeSm10x<cutlass::arch::Sm103, T>(D, A, B, input_sf, weight_sf,
+                global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
+#else
+            return dispatchNVFP4xNVFP4GemmCTAShapeSm10x<cutlass::arch::Sm100, T>(D, A, B, input_sf, weight_sf,
+                global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
+#endif
         }
+        else if (mSm == 100)
+        {
+            return dispatchNVFP4xNVFP4GemmCTAShapeSm10x<cutlass::arch::Sm100, T>(D, A, B, input_sf, weight_sf,
+                global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
+        }
         else if (mSm == 120 || mSm == 121)
         {
@@ -386,12 +406,15 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon

     if (mSm == 100 || mSm == 103)
     {
-        std::vector<tkc::CutlassTileConfigSM100> tilesSm100 = {
-            tkc::CutlassTileConfigSM100::CtaShape128x64x128B,
-            tkc::CutlassTileConfigSM100::CtaShape128x256x128B,
+        std::vector<tkc::CutlassTileConfigSM100> tilesSm10x = {
             tkc::CutlassTileConfigSM100::CtaShape128x128x256B,
             tkc::CutlassTileConfigSM100::CtaShape128x256x256B,
         };
+        if (mSm == 100)
+        {
+            tilesSm10x.push_back(tkc::CutlassTileConfigSM100::CtaShape128x64x128B);
+            tilesSm10x.push_back(tkc::CutlassTileConfigSM100::CtaShape128x256x128B);
+        }
         std::vector<tkc::ClusterShape> clusterShapes = {
             tkc::ClusterShape::ClusterShape_1x1x1,
             tkc::ClusterShape::ClusterShape_1x2x1,
@@ -402,7 +425,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
             tkc::ClusterShape::ClusterShape_2x4x1,
             tkc::ClusterShape::ClusterShape_4x4x1,
         };
-        for (auto const& tile_config : tilesSm100)
+        for (auto const& tile_config : tilesSm10x)
         {
             for (auto const& cluster_config : clusterShapes)
             {
@@ -417,7 +440,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
                    }
                }
                CutlassGemmConfig config(
-                    tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, cluster_config);
+                    tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, cluster_config, mSm);
                candidateConfigs.push_back(config);
            }
        }
@@ -60,13 +60,13 @@ struct _2SM
 {
 };

-template <typename T>
+template <typename Arch, typename T>
 struct SMTypeAdapter
 {
 };

 template <>
-struct SMTypeAdapter<_1SM>
+struct SMTypeAdapter<cutlass::arch::Sm100, _1SM>
 {
     static int const Scale = 1;
     using AtomThrShape = cute::Shape<_1, _1, _1>;
@@ -75,7 +75,7 @@ struct SMTypeAdapter<_1SM>
 };

 template <>
-struct SMTypeAdapter<_2SM>
+struct SMTypeAdapter<cutlass::arch::Sm100, _2SM>
 {
     static int const Scale = 2;
     using AtomThrShape = cute::Shape<_2, _1, _1>;
@@ -83,11 +83,29 @@ struct SMTypeAdapter<_2SM>
     using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
 };

+template <>
+struct SMTypeAdapter<cutlass::arch::Sm103, _1SM>
+{
+    static int const Scale = 1;
+    using AtomThrShape = cute::Shape<_1, _1, _1>;
+    using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized1Sm;
+    using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103;
+};
+
+template <>
+struct SMTypeAdapter<cutlass::arch::Sm103, _2SM>
+{
+    static int const Scale = 2;
+    using AtomThrShape = cute::Shape<_2, _1, _1>;
+    using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized2Sm;
+    using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103;
+};
+
+template <typename>
+constexpr auto always_false = false;
+
-template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_, typename CGA_M_, typename CGA_N_,
-    typename CGA_K_, typename XSM_>
+template <typename Arch, typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_, typename CGA_M_,
+    typename CGA_N_, typename CGA_K_, typename XSM_>
 size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf,
     float const* global_sf, int m, int n, int k, int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace,
     size_t const workspaceBytes, cudaStream_t stream, int* occupancy)
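The adapter is now keyed on both the architecture and the 1-SM/2-SM tag, so the Sm103 specializations above select the UltraVs16 block-scaled schedules. A sketch of how the launcher template further down consumes it, taking the Sm103, _2SM case as an example:

    using EpilogueSchedule = SMTypeAdapter<cutlass::arch::Sm103, _2SM>::EpilogueSchedule;
    // -> cutlass::epilogue::NoSmemWarpSpecialized2Sm
    using MainloopSchedule = SMTypeAdapter<cutlass::arch::Sm103, _2SM>::MainloopSchedule;
    // -> cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103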
@@ -98,13 +116,13 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void

 #ifdef PLACEHOLDER_KERNELS

-#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
+#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(ARCH_, T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
     template <> \
-    size_t genericFp4GemmKernelLauncher<T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>, cute::Int<CGA_M_>, \
-        cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, void const* B, void const* input_sf, \
-        void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, \
-        tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, \
-        int* occupancy) \
+    size_t genericFp4GemmKernelLauncher<cutlass::arch::ARCH_, T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, \
+        cute::Int<CTA_K_>, cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, \
+        void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, \
+        int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, \
+        cudaStream_t stream, int* occupancy) \
     { \
         throw std::runtime_error( \
             "[TensorRT-LLM Error][FP4 gemm Runner] TensorRT-LLM is not compiled with support for this Architecture."); \
@@ -112,15 +130,15 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void

 #else

-#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
-    struct DeviceGemmFp4GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ \
+#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(ARCH_, T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \
+    struct DeviceGemmFp4Gemm##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ \
     { \
         using OutElementType = TllmToCutlassTypeAdapter<T>::type; \
         using CTAShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
         /*using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>;*/ \
         using ClusterShape = cute::Shape<int, int, _1>; \
         using ElementType = cutlass::float_e2m1_t; \
-        using Arch = cutlass::arch::Sm100; \
+        using Arch = cutlass::arch::ARCH_; \
         /* // Input A */ \
         using ElementA = ElementType; \
         using LayoutA = cutlass::layout::RowMajor; \
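As an illustration of the token pasting above, an instantiation such as INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 1, 1, _1SM) now defines a struct along these lines (sketch, body elided):

    struct DeviceGemmFp4GemmSm103_half_128_128_256_1_1_1_1SM
    {
        using Arch = cutlass::arch::Sm103;
        /* CTAShape, element types, collectives, Gemm typedef, ... */
    };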
@@ -140,10 +158,10 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
         using OperatorClass = cutlass::arch::OpClassTensorOp; \
         using EpilogueTileType = std::conditional_t<CTA_M_ == 128 && CTA_N_ == 256 && CTA_K_ == 256, \
             cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
-        using EpilogueSchedule = SMTypeAdapter<XSM_>::EpilogueSchedule; \
-        using MainloopSchedule = SMTypeAdapter<XSM_>::MainloopSchedule; \
-        using MmaTileShape \
-            = cute::Shape<cute::Int<CTA_M_ * SMTypeAdapter<XSM_>::Scale>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
+        using EpilogueSchedule = SMTypeAdapter<Arch, XSM_>::EpilogueSchedule; \
+        using MainloopSchedule = SMTypeAdapter<Arch, XSM_>::MainloopSchedule; \
+        using MmaTileShape = cute::Shape<cute::Int<CTA_M_ * SMTypeAdapter<Arch, XSM_>::Scale>, cute::Int<CTA_N_>, \
+            cute::Int<CTA_K_*(std::is_same_v<Arch, cutlass::arch::Sm103> ? 3 : 1)>>; \
         using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<Arch, OperatorClass, \
             MmaTileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, LayoutC, \
             AlignmentC, OutElementType, LayoutC, AlignmentC, EpilogueSchedule, \
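This appears to be where the "3x" in the commit title comes from: on Sm103 the K extent of the MMA tile is three times the CTA K tile, while Sm100 leaves it unchanged. For a 128x256x256 CTA with the _2SM tag (Scale = 2), the resulting shapes would be roughly:

    // Sm100: MmaTileShape = cute::Shape<cute::Int<256>, cute::Int<256>, cute::Int<256>>
    // Sm103: MmaTileShape = cute::Shape<cute::Int<256>, cute::Int<256>, cute::Int<768>>   (CTA_K_ * 3)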
@@ -185,7 +203,7 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
     \
     template <typename Gemm> \
     typename Gemm::Arguments \
-    prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_(void* D, \
+    prepareGemmArgs_##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_(void* D, \
         void const* A, void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, \
         int n, int k, int batch_count) \
     { \
@@ -234,17 +252,18 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
             operator_args.scheduler.raster_order = Enum_t::Heuristic; \
         } \
         operator_args.hw_info.cluster_shape = dim3(CGA_M_, CGA_N_, CGA_K_); \
-        operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter<XSM_>::Scale, 1, 1); \
+        using Arch = cutlass::arch::ARCH_; \
+        operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter<Arch, XSM_>::Scale, 1, 1); \
         \
         return operator_args; \
     } \
     \
     template <> \
-    size_t genericFp4GemmKernelLauncher<T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>, cute::Int<CGA_M_>, \
-        cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, void const* B, void const* input_sf, \
-        void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, \
-        tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, \
-        int* occupancy) \
+    size_t genericFp4GemmKernelLauncher<cutlass::arch::ARCH_, T, cute::Int<CTA_M_>, cute::Int<CTA_N_>, \
+        cute::Int<CTA_K_>, cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>, XSM_>(void* D, void const* A, \
+        void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, \
+        int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, \
+        cudaStream_t stream, int* occupancy) \
     { \
         using ElementOutput__ = typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, \
             cutlass::half_t, T>::type; \
@@ -256,11 +275,12 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
             cutlass::bfloat16_t, ElementOutput_>::type; \
         \
         using Fp4GemmOperator \
-            = DeviceGemmFp4GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \
+            = DeviceGemmFp4Gemm##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \
                 Gemm; \
         Fp4GemmOperator gemm; \
-        auto args = prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \
-            Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \
+        auto args \
+            = prepareGemmArgs_##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \
+                Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \
         /* // Check shared memory size; throw when SMEM exceeds */ \
         int smem_size = int(sizeof(typename Fp4GemmOperator::GemmKernel::SharedStorage)); \
         static int mMaxSmemSize = tk::getMaxSharedMemoryPerBlockOptin(); \
@@ -64,6 +64,12 @@ tkc::CutlassGemmConfig getDefaultGemmConfig(int64_t m, int64_t n, int64_t k, FP4
             tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO,
             tkc::ClusterShape::ClusterShape_1x1x1);
     }
+    else if (sm == 103)
+    {
+        return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x256B,
+            tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO,
+            tkc::ClusterShape::ClusterShape_1x1x1);
+    }
     else
     {
         return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x128B,