diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu index 5c6013407a..cbf33a9ce5 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu @@ -23,38 +23,55 @@ namespace kernels namespace cutlass_kernels { #ifdef ENABLE_BF16 -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 64, 128, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 128, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 64, 128, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 128, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, __nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM) + +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 128, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, __nv_bfloat16, 128, 256, 256, 4, 4, 1, _2SM) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 128, 128, 128, 1, 1, 1) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 128, 128, 256, 1, 1, 1) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu index 7f41b93e66..0b232fb95b 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu @@ -22,38 +22,55 @@ namespace kernels { namespace cutlass_kernels { -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 64, 128, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 128, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 128, 256, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(half, 128, 256, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 64, 128, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 128, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 128, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, half, 128, 256, 256, 4, 4, 1, _2SM) + +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 128, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, half, 128, 256, 256, 4, 4, 1, _2SM) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 128, 128, 128, 1, 1, 1) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 128, 128, 256, 1, 1, 1) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu index abfbd30c44..d733c97f6b 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu @@ -22,38 +22,55 @@ namespace kernels { namespace cutlass_kernels { -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 64, 128, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 128, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 128, 256, 4, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 1, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 2, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 1, 4, 1, _1SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 1, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 2, 4, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 4, 2, 1, _2SM) -INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(float, 128, 256, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 64, 128, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 128, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 128, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm100, float, 128, 256, 256, 4, 4, 1, _2SM) + +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 128, 256, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 1, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 2, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 1, 4, 1, _1SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 1, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 2, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 4, 2, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(Sm103, float, 128, 256, 256, 4, 4, 1, _2SM) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 128, 128, 128, 1, 1, 1) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 128, 128, 256, 1, 1, 1) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h index f8f01f1a85..a404f9c588 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h @@ -16,6 +16,7 @@ #pragma once +#include #ifndef _WIN32 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -52,8 +53,8 @@ using namespace cute; namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; -template -size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const* B, void const* input_sf, +template +size_t dispatchNVFP4xNVFP4GemmClusterShapeSm10x(T* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) @@ -64,43 +65,43 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const* switch (gemmConfig.cluster_shape) { case tkc::ClusterShape::ClusterShape_1x1x1: - return genericFp4GemmKernelLauncher, cute::Int<1>, cute::Int<1>, _1SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<1>, cute::Int<1>, + _1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_2x1x1: - return genericFp4GemmKernelLauncher, cute::Int<1>, cute::Int<1>, _2SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<1>, cute::Int<1>, + _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_1x2x1: - return genericFp4GemmKernelLauncher, cute::Int<2>, cute::Int<1>, _1SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<2>, cute::Int<1>, + _1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_2x2x1: - return genericFp4GemmKernelLauncher, cute::Int<2>, cute::Int<1>, _2SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<2>, cute::Int<1>, + _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_1x4x1: - return genericFp4GemmKernelLauncher, cute::Int<4>, cute::Int<1>, _1SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<4>, cute::Int<1>, + _1SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_4x2x1: - return genericFp4GemmKernelLauncher, cute::Int<2>, cute::Int<1>, _2SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<2>, cute::Int<1>, + _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_2x4x1: - return genericFp4GemmKernelLauncher, cute::Int<4>, cute::Int<1>, _2SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<4>, cute::Int<1>, + _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; case tkc::ClusterShape::ClusterShape_4x4x1: - return genericFp4GemmKernelLauncher, cute::Int<4>, cute::Int<1>, _2SM>( - D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + return genericFp4GemmKernelLauncher, cute::Int<4>, cute::Int<1>, + _2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); break; default: @@ -110,8 +111,8 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const* } } -template -size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B, void const* input_sf, +template +size_t dispatchNVFP4xNVFP4GemmCTAShapeSm10x(T* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) @@ -123,39 +124,48 @@ size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B, // M-mode size should be 128 or 256 for 2 CTA cluster MMA; // M-mode size should be 128 for 1 CTA cluster OMMA. // K256 looks to be better than K128 - switch (gemmConfig.tile_config_sm100) +#define CTA_CASE(M, N, K) \ + case tkc::CutlassTileConfigSM100::CtaShape##M##x##N##x##K##B: \ + return dispatchNVFP4xNVFP4GemmClusterShapeSm10x, cute::Int, cute::Int>(D, A, B, \ + input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, \ + occupancy); +#define CTA_CASE_DEFAULT \ + case tkc::CutlassTileConfigSM100::Undefined: \ + throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined."); \ + break; \ + case tkc::CutlassTileConfigSM100::ChooseWithHeuristic: \ + throw std::runtime_error( \ + "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by " \ + "heuristic."); \ + break; \ + default: \ + throw std::runtime_error( \ + "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM."); \ + break; + if constexpr (std::is_same_v) + { + switch (gemmConfig.tile_config_sm100) + { + CTA_CASE(128, 64, 128) + CTA_CASE(128, 256, 128) + CTA_CASE(128, 128, 256) + CTA_CASE(128, 256, 256) + CTA_CASE_DEFAULT + } + } + else if constexpr (std::is_same_v) + { + switch (gemmConfig.tile_config_sm100) + { + CTA_CASE(128, 128, 256) + CTA_CASE(128, 256, 256) + CTA_CASE_DEFAULT + } + } + else { - case tkc::CutlassTileConfigSM100::CtaShape128x64x128B: - return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<64>, cute::Int<128>>(D, A, B, - input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, - occupancy); - break; - case tkc::CutlassTileConfigSM100::CtaShape128x256x128B: - return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<256>, cute::Int<128>>(D, A, B, - input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, - occupancy); - break; - case tkc::CutlassTileConfigSM100::CtaShape128x128x256B: - return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<128>, cute::Int<256>>(D, A, B, - input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, - occupancy); - break; - case tkc::CutlassTileConfigSM100::CtaShape128x256x256B: - return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<256>, cute::Int<256>>(D, A, B, - input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, - occupancy); - break; - case tkc::CutlassTileConfigSM100::Undefined: - throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined."); - break; - case tkc::CutlassTileConfigSM100::ChooseWithHeuristic: throw std::runtime_error( - "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by " - "heuristic."); - break; - default: - throw std::runtime_error("[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM."); - break; + "[TensorRT-LLM Error][FP4][dispatch_gemm_cta_shape] Architecture not supported for FP4 GEMM."); } } @@ -343,10 +353,20 @@ size_t CutlassFp4GemmRunner::dispatchToArch(T* D, void const* A, } else if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4) { - if (mSm == 100 || mSm == 103) + if (mSm == 103) { - return dispatchNVFP4xNVFP4GemmCTAShapeSm100(D, A, B, input_sf, weight_sf, global_sf, m, n, k, - batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); +#ifdef COMPILE_BLACKWELL_SM103_TMA_GEMMS + return dispatchNVFP4xNVFP4GemmCTAShapeSm10x(D, A, B, input_sf, weight_sf, + global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); +#else + return dispatchNVFP4xNVFP4GemmCTAShapeSm10x(D, A, B, input_sf, weight_sf, + global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); +#endif + } + else if (mSm == 100) + { + return dispatchNVFP4xNVFP4GemmCTAShapeSm10x(D, A, B, input_sf, weight_sf, + global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); } else if (mSm == 120 || mSm == 121) { @@ -386,12 +406,15 @@ std::vector CutlassFp4GemmRunner::getCon if (mSm == 100 || mSm == 103) { - std::vector tilesSm100 = { - tkc::CutlassTileConfigSM100::CtaShape128x64x128B, - tkc::CutlassTileConfigSM100::CtaShape128x256x128B, + std::vector tilesSm10x = { tkc::CutlassTileConfigSM100::CtaShape128x128x256B, tkc::CutlassTileConfigSM100::CtaShape128x256x256B, }; + if (mSm == 100) + { + tilesSm10x.push_back(tkc::CutlassTileConfigSM100::CtaShape128x64x128B); + tilesSm10x.push_back(tkc::CutlassTileConfigSM100::CtaShape128x256x128B); + } std::vector clusterShapes = { tkc::ClusterShape::ClusterShape_1x1x1, tkc::ClusterShape::ClusterShape_1x2x1, @@ -402,7 +425,7 @@ std::vector CutlassFp4GemmRunner::getCon tkc::ClusterShape::ClusterShape_2x4x1, tkc::ClusterShape::ClusterShape_4x4x1, }; - for (auto const& tile_config : tilesSm100) + for (auto const& tile_config : tilesSm10x) { for (auto const& cluster_config : clusterShapes) { @@ -417,7 +440,7 @@ std::vector CutlassFp4GemmRunner::getCon } } CutlassGemmConfig config( - tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, cluster_config); + tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, cluster_config, mSm); candidateConfigs.push_back(config); } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/nvfp4_nvfp4_gemm_template_sm100.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/nvfp4_nvfp4_gemm_template_sm100.h index da7b303351..ef7b101bf7 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/nvfp4_nvfp4_gemm_template_sm100.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/nvfp4_nvfp4_gemm_template_sm100.h @@ -60,13 +60,13 @@ struct _2SM { }; -template +template struct SMTypeAdapter { }; template <> -struct SMTypeAdapter<_1SM> +struct SMTypeAdapter { static int const Scale = 1; using AtomThrShape = cute::Shape<_1, _1, _1>; @@ -75,7 +75,7 @@ struct SMTypeAdapter<_1SM> }; template <> -struct SMTypeAdapter<_2SM> +struct SMTypeAdapter { static int const Scale = 2; using AtomThrShape = cute::Shape<_2, _1, _1>; @@ -83,11 +83,29 @@ struct SMTypeAdapter<_2SM> using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; }; +template <> +struct SMTypeAdapter +{ + static int const Scale = 1; + using AtomThrShape = cute::Shape<_1, _1, _1>; + using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103; +}; + +template <> +struct SMTypeAdapter +{ + static int const Scale = 2; + using AtomThrShape = cute::Shape<_2, _1, _1>; + using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103; +}; + template constexpr auto always_false = false; -template +template size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t const workspaceBytes, cudaStream_t stream, int* occupancy) @@ -98,13 +116,13 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void #ifdef PLACEHOLDER_KERNELS -#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \ +#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(ARCH_, T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \ template <> \ - size_t genericFp4GemmKernelLauncher, cute::Int, cute::Int, cute::Int, \ - cute::Int, cute::Int, XSM_>(void* D, void const* A, void const* B, void const* input_sf, \ - void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, \ - tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, \ - int* occupancy) \ + size_t genericFp4GemmKernelLauncher, cute::Int, \ + cute::Int, cute::Int, cute::Int, cute::Int, XSM_>(void* D, void const* A, \ + void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, \ + int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, \ + cudaStream_t stream, int* occupancy) \ { \ throw std::runtime_error( \ "[TensorRT-LLM Error][FP4 gemm Runner] TensorRT-LLM is not compiled with support for this Architecture."); \ @@ -112,15 +130,15 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void #else -#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \ - struct DeviceGemmFp4GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ \ +#define INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(ARCH_, T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, XSM_) \ + struct DeviceGemmFp4Gemm##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ \ { \ using OutElementType = TllmToCutlassTypeAdapter::type; \ using CTAShape = cute::Shape, cute::Int, cute::Int>; \ /*using ClusterShape = cute::Shape, cute::Int, cute::Int>;*/ \ using ClusterShape = cute::Shape; \ using ElementType = cutlass::float_e2m1_t; \ - using Arch = cutlass::arch::Sm100; \ + using Arch = cutlass::arch::ARCH_; \ /* // Input A */ \ using ElementA = ElementType; \ using LayoutA = cutlass::layout::RowMajor; \ @@ -140,10 +158,10 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void using OperatorClass = cutlass::arch::OpClassTensorOp; \ using EpilogueTileType = std::conditional_t, cutlass::epilogue::collective::EpilogueTileAuto>; \ - using EpilogueSchedule = SMTypeAdapter::EpilogueSchedule; \ - using MainloopSchedule = SMTypeAdapter::MainloopSchedule; \ - using MmaTileShape \ - = cute::Shape::Scale>, cute::Int, cute::Int>; \ + using EpilogueSchedule = SMTypeAdapter::EpilogueSchedule; \ + using MainloopSchedule = SMTypeAdapter::MainloopSchedule; \ + using MmaTileShape = cute::Shape::Scale>, cute::Int, \ + cute::Int ? 3 : 1)>>; \ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder \ typename Gemm::Arguments \ - prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_(void* D, \ + prepareGemmArgs_##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_(void* D, \ void const* A, void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, \ int n, int k, int batch_count) \ { \ @@ -234,17 +252,18 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void operator_args.scheduler.raster_order = Enum_t::Heuristic; \ } \ operator_args.hw_info.cluster_shape = dim3(CGA_M_, CGA_N_, CGA_K_); \ - operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter::Scale, 1, 1); \ + using Arch = cutlass::arch::ARCH_; \ + operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter::Scale, 1, 1); \ \ return operator_args; \ } \ \ template <> \ - size_t genericFp4GemmKernelLauncher, cute::Int, cute::Int, cute::Int, \ - cute::Int, cute::Int, XSM_>(void* D, void const* A, void const* B, void const* input_sf, \ - void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, \ - tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, \ - int* occupancy) \ + size_t genericFp4GemmKernelLauncher, cute::Int, \ + cute::Int, cute::Int, cute::Int, cute::Int, XSM_>(void* D, void const* A, \ + void const* B, void const* input_sf, void const* weight_sf, float const* global_sf, int m, int n, int k, \ + int batch_count, tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, \ + cudaStream_t stream, int* occupancy) \ { \ using ElementOutput__ = typename cutlass::platform::conditional::value, \ cutlass::half_t, T>::type; \ @@ -256,11 +275,12 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void cutlass::bfloat16_t, ElementOutput_>::type; \ \ using Fp4GemmOperator \ - = DeviceGemmFp4GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \ + = DeviceGemmFp4Gemm##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \ Gemm; \ Fp4GemmOperator gemm; \ - auto args = prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \ - Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \ + auto args \ + = prepareGemmArgs_##ARCH_##_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \ + Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \ /* // Check shared memory size; throw when SMEM exceeds */ \ int smem_size = int(sizeof(typename Fp4GemmOperator::GemmKernel::SharedStorage)); \ static int mMaxSmemSize = tk::getMaxSharedMemoryPerBlockOptin(); \ diff --git a/cpp/tensorrt_llm/thop/fp4Gemm.cpp b/cpp/tensorrt_llm/thop/fp4Gemm.cpp index 327a3537c0..2fa818bdee 100644 --- a/cpp/tensorrt_llm/thop/fp4Gemm.cpp +++ b/cpp/tensorrt_llm/thop/fp4Gemm.cpp @@ -64,6 +64,12 @@ tkc::CutlassGemmConfig getDefaultGemmConfig(int64_t m, int64_t n, int64_t k, FP4 tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, tkc::ClusterShape::ClusterShape_1x1x1); } + else if (sm == 103) + { + return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x256B, + tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, + tkc::ClusterShape::ClusterShape_1x1x1); + } else { return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x128B,