From 469a38d0d8258fe2f3cbeb1c1efbaf8f5fbaad67 Mon Sep 17 00:00:00 2001
From: Daniel Stokes
Date: Tue, 8 Jul 2025 12:37:34 +1200
Subject: [PATCH] feat: Add support for SM103 3xFP4 tile shapes

Signed-off-by: Daniel Stokes
---
 .gitmodules                                   |   3 +
 3rdparty/cutlass                              |   2 +-
 3rdparty/dynamic-kernel-generator             |   1 +
 cpp/CMakeLists.txt                            |   4 +-
 cpp/cmake/modules/cuda_configuration.cmake    |  12 +-
 .../include/cutlass_extensions/gemm_configs.h |  12 +-
 .../kernels/cutlass_kernels/CMakeLists.txt    |  22 ++-
 .../cutlass_kernels/cutlass_heuristic.cpp     | 133 +++++++++---------
 .../fpA_intB_gemm/fpA_intB_gemm_template.h    |   7 +-
 .../fpA_intB_gemm_template_sm90.h             |   4 +-
 .../launchers/fpA_intB_launcher_sm90.h        |   4 +-
 .../launchers/fpA_intB_launcher_sm90.inl      |   5 +-
 .../include/moe_gemm_kernels.h                |   5 +-
 .../launchers/fused_moe_gemm_launcher_sm80.h  |   2 +-
 .../fused_moe_gemm_launcher_sm80.inl          |   4 +-
 .../launchers/moe_gemm_tma_ws_launcher.h      |   6 +-
 .../launchers/moe_gemm_tma_ws_launcher.inl    |  81 ++++++-----
 .../moe_gemm_tma_ws_mixed_input_launcher.h    |  10 +-
 .../moe_gemm_tma_ws_mixed_input_launcher.inl  |   4 +-
 .../moe_gemm/moe_gemm_template_dispatch.h     | 102 ++++++++------
 .../moe_gemm_template_dispatch_tma_ws.h       |  56 ++++++--
 ...emm_template_dispatch_tma_ws_mixed_dtype.h |   6 +-
 .../python/generate_kernels.py                |  47 +++----
 .../include/moe_gemm_kernels.h                |   1 -
 cpp/tests/CMakeLists.txt                      |   4 +-
 cpp/tests/unit_tests/kernels/CMakeLists.txt   |   3 +
 .../inflight_batcher_llm/CMakeLists.txt       |   2 +-
 27 files changed, 320 insertions(+), 222 deletions(-)
 create mode 160000 3rdparty/dynamic-kernel-generator

diff --git a/.gitmodules b/.gitmodules
index 00ff73d136..cb7cb5e4ac 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/cppzmq"]
 	path = 3rdparty/cppzmq
 	url = https://github.com/zeromq/cppzmq.git
+[submodule "3rdparty/dynamic-kernel-generator"]
+	path = 3rdparty/dynamic-kernel-generator
+	url = ssh://git@gitlab-master.nvidia.com:12051/dlarch-fastkernels/dynamic-kernel-generator.git
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index dc4817921e..a1aaf2300a 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit dc4817921edda44a549197ff3a9dcf5df0636e7b
+Subproject commit a1aaf2300a8fc3a8106a05436e1a2abad0930443
diff --git a/3rdparty/dynamic-kernel-generator b/3rdparty/dynamic-kernel-generator
new file mode 160000
index 0000000000..34bfe35573
--- /dev/null
+++ b/3rdparty/dynamic-kernel-generator
@@ -0,0 +1 @@
+Subproject commit 34bfe3557372d1d2cebe3c90448b03756c6a16eb
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 24779ef6b0..44127ab087 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -215,8 +215,8 @@ include_directories(
   ${CUDAToolkit_INCLUDE_DIRS}/cccl
   ${CUDNN_ROOT_DIR}/include
   $
-  ${3RDPARTY_DIR}/cutlass/include
-  ${3RDPARTY_DIR}/cutlass/tools/util/include
+  ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/include
+  ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/tools/util/include
   ${3RDPARTY_DIR}/NVTX/include
   ${3RDPARTY_DIR}/json/include)
 if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
diff --git a/cpp/cmake/modules/cuda_configuration.cmake b/cpp/cmake/modules/cuda_configuration.cmake
index 3e40c9b15b..b9d910fdee 100644
--- a/cpp/cmake/modules/cuda_configuration.cmake
+++ b/cpp/cmake/modules/cuda_configuration.cmake
@@ -150,6 +150,9 @@ function(setup_cuda_architectures)
     if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.7")
       list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 100 120)
     endif()
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 103)
+    endif()
   endif()

   # CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without
@@ -160,7 +163,14 @@ function(setup_cuda_architectures)
       ${CMAKE_CUDA_ARCHITECTURES_ORIG}
       PARENT_SCOPE)

-  set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 120)
+  set(ARCHITECTURES_WITH_KERNELS
+      80
+      86
+      89
+      90
+      100
+      103
+      120)
   foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
     if(NOT ${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
       add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
index f9355860be..3fab43a3b4 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
@@ -133,13 +133,10 @@ enum class CutlassTileConfigSM100
     CtaShape128x256x128B,
     CtaShape128x128x256B,
     CtaShape128x256x256B,
-
-    // M=256
-    CtaShape256x64x128B,
-    CtaShape256x128x128B,
-    CtaShape256x256x128B,
 };

+using CutlassTileConfigSM103 = CutlassTileConfigSM100;
+
 enum class CutlassTileConfigSM120
 {
     // Signals that we should run heuristics do choose a config
@@ -461,14 +458,15 @@ struct CutlassGemmConfig
     }

     CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
-        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
+        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape, int sm_version = 100)
         : tile_config_sm100(tile_config_sm100)
         , mainloop_schedule(mainloop_schedule)
         , epilogue_schedule(epilogue_schedule)
         , cluster_shape(cluster_shape)
-        , sm_version(100)
+        , sm_version(sm_version)
         , is_tma_warp_specialized(true)
     {
+        assert(sm_version >= 100 && sm_version < 120 && "Expected SM 10x version");
     }

     CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule,
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
index 7a02cdee73..5a0eb518ee 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
@@ -28,7 +28,7 @@ if(NOT Python3_EXECUTABLE)
 endif()

 execute_process(
-  WORKING_DIRECTORY ${3RDPARTY_DIR}/cutlass/python/
+  WORKING_DIRECTORY ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/python/
   COMMAND ${Python3_EXECUTABLE} setup_library.py develop --user
   RESULT_VARIABLE _CUTLASS_LIBRARY_SUCCESS)

@@ -72,10 +72,14 @@ function(process_target target_name enable_hopper enable_blackwell)

   if(${enable_blackwell}
      AND ("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
+          OR "103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
          OR "120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
-         OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG))
+         OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
+    ))

-    if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
+    # Both 100 and 103 support these kernels
+    if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
+       OR "103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
      # No kernels should be parsed, unless blackwell is specified. This is a
      # build time improvement
      target_compile_definitions(${target_name}
@@ -83,6 +87,13 @@ function(process_target target_name enable_hopper enable_blackwell)
       target_compile_definitions(${target_name}
                                  PUBLIC COMPILE_BLACKWELL_TMA_GROUPED_GEMMS)
     endif()
+    # SM103 only kernels
+    if("103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
+      target_compile_definitions(${target_name}
+                                 PUBLIC COMPILE_BLACKWELL_SM103_TMA_GEMMS)
+      target_compile_definitions(
+        ${target_name} PUBLIC COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS)
+    endif()
     if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
        OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
       target_compile_definitions(${target_name}
@@ -113,6 +124,8 @@ function(add_instantiations library base_dir)
     list(LENGTH INSTANTIATIONS_GENERATED_${ARCH} n)
     if(${n} GREATER 0)
       set(TARGET_NAME "_${library}_instantiations_${ARCH}")
+      message(
+        STATUS "Adding target ${TARGET_NAME} with instantiations for ${ARCH}")
       add_library(${TARGET_NAME} OBJECT ${INSTANTIATIONS_GENERATED_${ARCH}})
       target_link_libraries(${library} PRIVATE ${TARGET_NAME})
       set_cuda_architectures(${TARGET_NAME} ${BUILD_ARCHS})
@@ -128,6 +141,7 @@ function(add_instantiations library base_dir)
   glob_src_create_target(80 "80;86")
   glob_src_create_target(90 90)
   glob_src_create_target(100 100f)
+  glob_src_create_target(103 103)
   glob_src_create_target(120 120f)
 endfunction()

@@ -231,7 +245,7 @@ if(USING_OSS_CUTLASS_MOE_GEMM)
   add_cuda_architectures(_moe_gemm_launcher 89)

   add_library(_moe_gemm_fp4 OBJECT ${MOE_GEMM_SRC_CU_FP4})
-  set_cuda_architectures(_moe_gemm_fp4 100f 120f)
+  set_cuda_architectures(_moe_gemm_fp4 100f 103 120f)
   process_target(_moe_gemm_fp4 false true)

   add_library(_moe_gemm_fp8 OBJECT ${MOE_GEMM_SRC_CU_FP8})
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
index 9e3bbaa32b..29e1528f1e 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
@@ -367,7 +367,8 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
     return candidate_configs;
 }

-std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::CandidateConfigTypeParam const config)
+std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
+    CutlassGemmConfig::CandidateConfigTypeParam const config, int sm)
 {
 #ifdef FAST_BUILD
     // Fast build disables all configs except this one for SM100
@@ -377,72 +378,78 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
     if (config & CutlassGemmConfig::GROUPED_GEMM)
     {
         std::vector<CutlassGemmConfig> candidate_configs;
-        if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
+        if (config & CutlassGemmConfig::FP4_ONLY)
         {
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
-            return candidate_configs;
+            if (sm == 103)
+            {
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x128x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, sm});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x128x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1, sm});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x256x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, sm});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x256x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1, sm});
+                return candidate_configs;
+            }
+            else
+            {
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
+                // TODO These need a specific epilogue sub tile (128, 64), not EpilogueTileAuto, otherwise they crash
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
+                candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
+                    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
+                return candidate_configs;
+            }
         }

-        for (int cluster_m = 1; cluster_m <= 2; cluster_m++)
+        std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1},
+        };
+
+        if (config & CutlassGemmConfig::FP8_ONLY)
         {
-            bool Is2SM = cluster_m == 2;
-            for (int cluster_n = 1; cluster_n <= 2; cluster_n++)
-            {
-                std::vector<CutlassTileConfigSM100> base = {// M=128
-                    CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B};
+            tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1});
+            // TODO(sklevtsov): re-enable when handled by the MoE GEMM dispatch
+            // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 });
+        }

-            if (Is2SM)
-            {
-                if (cluster_n == 1)
-                {
-                    base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
-                    base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
-                }
-
-                std::vector<CutlassTileConfigSM100> twosm = {// M=256
-                    CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B};
-                std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
-            }
-            else
-            {
-                if (cluster_n == 1)
-                {
-                    base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
-                    if ((config & CutlassGemmConfig::FP8_ONLY) != 0)
-                    {
-                        base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
-                    }
-                }
-
-                std::vector<CutlassTileConfigSM100> onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
-                    CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B,
-                    CutlassTileConfigSM100::CtaShape128x64x128B};
-                std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
-            }
-
-            constexpr std::array cluster_shapes
-                = {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
-                    std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
-            auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
-            for (auto tile : base)
-            {
-                CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
-                candidate_configs.push_back(config);
-            }
-        }
+        for (auto [tile, cluster] : tile_configs)
+        {
+            CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
+            candidate_configs.push_back(config);
+        }
         return candidate_configs;
     }
@@ -523,7 +530,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
     }
     if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL))
     {
-        return get_candidate_configs_sm100(config_type_param);
+        return get_candidate_configs_sm100(config_type_param, sm);
     }
     if (sm >= 120 && (config_type_param & CutlassGemmConfig::BLACKWELL))
     {
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
index 07ea2923fb..9b8b057226 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
@@ -50,6 +50,7 @@ namespace kernels
 {
 namespace cutlass_kernels
 {
+using namespace cute;

 template ::value || cutlass::platform::is_same::value,
             "ScaleZeroType must be half for activation=fp8");
-        sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr,
-            workspace_bytes, gemm_config, stream, occupancy);
+        cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
+            group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
     }
     else
     {
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h
index 2253e2339b..7c33d610c0 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h
@@ -30,7 +30,7 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-namespace cutlass_kernels
+namespace cutlass_kernels_oss
 {
 namespace tk = tensorrt_llm::common;
 namespace tkc = tensorrt_llm::cutlass_extensions;
@@ -268,6 +268,6 @@ void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B,
     }
 }

-} // namespace cutlass_kernels
+} // namespace cutlass_kernels_oss
 } // namespace kernels
 } // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h
index 9405287ccf..b90970f0c2 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h
@@ -22,7 +22,7 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-namespace cutlass_kernels
+namespace cutlass_kernels_oss
 {
 template ;
-    static constexpr bool use_wfp4afp4 = std::is_same_v<WeightType, __nv_fp4_e2m1> && std::is_same_v<T, __nv_fp8_e4m3>;
+    static constexpr bool use_wfp4afp8 = std::is_same_v<WeightType, __nv_fp4_e2m1> && std::is_same_v<T, __nv_fp8_e4m3>;
 #else
     static constexpr bool use_fp4 = false;
-    static constexpr bool use_wfp4afp4 = false;
+    static constexpr bool use_wfp4afp8 = false;
 #endif

     void moeGemmBiasAct(GroupedGemmInput inputs,
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
index f4eed277c1..efc7d359f8 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-namespace tensorrt_llm::kernels::cutlass_kernels
+namespace tensorrt_llm::kernels::cutlass_kernels_oss
 {
 template 
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
index f2be4057b9..b08ebe2d63 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
@@ -27,7 +27,7 @@
 #include "cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh"
 #include "tensorrt_llm/common/cudaUtils.h"

-namespace tensorrt_llm::kernels::cutlass_kernels
+namespace tensorrt_llm::kernels::cutlass_kernels_oss
 {
 template 
@@ -93,4 +93,4 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
     auto result = cudaGetLastError();
     TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
 }
-} // namespace tensorrt_llm::kernels::cutlass_kernels
+} // namespace tensorrt_llm::kernels::cutlass_kernels_oss
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h
index dfede09995..8632e04317 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h
@@ -19,9 +19,9 @@
 #include "../../include/moe_gemm_kernels.h"
 #include

-namespace tensorrt_llm::kernels::cutlass_kernels
+namespace tensorrt_llm::kernels::cutlass_kernels_oss
 {
-
+using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
 // Keep in sync with the signature generated by generate_kernels.py
 template = 120) {
@@ -195,10 +202,26 @@ using SafeBF16 = void;
     using T = DataType_; \
     using WeightType = WeightType_; \
     using OutputType = OutputType_; \
-    using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \
-    using TileShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
-    using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \
     constexpr static bool IsMXFPX = MXFPX_; \
+    using Arch = ArchTag; \
+    constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \
+    constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \
+    constexpr static bool IsSM103 = ArchTag::kMinComputeCapability == 103; \
+    constexpr static bool IsWFP4AFP8 \
+        = cutlass::platform::is_same::value && cutlass::platform::is_same::value; \
+    constexpr static bool IsFP4 = cutlass::platform::is_same::value; \
+    static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \
+    \
+    constexpr static bool IsFP8 = cutlass::platform::is_same::value; \
+    \
+    constexpr static bool IsSM103FP4 = IsSM103 && IsFP4; \
+    static_assert(IsSM103 == IsSM103FP4, "SM103 only implemented for fp4"); \
+    \
+    constexpr static bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \
+    using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \
+    using MmaTileShape = cute::Shape, cute::Int, \
+        cute::Int>; \
+    using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \
     \
     if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \
         && ArchTag::kMinComputeCapability < 100) \
     { \
         TLLM_THROW( \
             "Please recompile with support for blackwell by passing 120-real as an arch to build_wheel.py."); \
     } \
-    else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) \
+    else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) \
     { \
         using namespace cute; \
         /* Helper class for defining all the cutlass types \
         // template \
+        //     typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
         // struct TmaWarpSpecializedGroupedGemmInfo \
         { */ \
-        using Arch = ArchTag; \
-        constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \
-        constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \
-        constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value \
-            && cutlass::platform::is_same::value; \
-        constexpr static bool IsFP4 = cutlass::platform::is_same::value; \
-        static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \
-        \
-        constexpr static bool IsFP8 = cutlass::platform::is_same::value; \
         \
         /* TODO Update once mixed input support is added */ \
         static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \
@@ -332,15 +346,16 @@ using SafeBF16 = void;
         constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \
         using EpilogueScheduleSM100 = std::conditional_t; \
+        using EpilogueScheduleSM103 \
+            = std::conditional_t; \
+        using EpilogueScheduleSM10x \
+            = std::conditional_t; \
+        \
         using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \
-        using EpilogueScheduleBW = std ::conditional_t; \
+        using EpilogueScheduleBW = std ::conditional_t; \
         using EpilogueSchedule = std::conditional_t; \
         \
-        using EpilogueTileShapeSm90 = TileShape; \
-        using AtomClusterDiv = std::conditional_t; \
-        using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape{})); \
-        using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
-        using EpilogueTileShape = std::conditional_t; \
         using EpilogueElementC = std::conditional_t; \
         using EpilogueTensorOp = std::conditional_t; \
@@ -350,7 +365,7 @@ using SafeBF16 = void;
         /* Epilogue For Default Finalize */ \
         using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder; \
         \
         /* TRT-LLM uses vector size 16 for block scaled */ \
+        using KernelScheduleSM103 = std::conditional_t; \
+        \
         using KernelScheduleSM100 = std::conditional_t, \
             std::conditional_t>; \
+        using KernelScheduleSM10x = std::conditional_t; \
+        \
         using KernelScheduleSM120 = cutlass ::gemm ::collective::KernelScheduleAuto; \
-        using KernelScheduleBW = std::conditional_t; \
+        using KernelScheduleBW = std::conditional_t; \
         \
         using KernelSchedule = std::conditional_t; \
         \
@@ -405,16 +426,12 @@ using SafeBF16 = void;
         using MainloopElementA = std::conditional_t; \
         using MainloopElementB = std::conditional_t; \
         \
-        using MainloopTileShapeSm90 = TileShape; \
-        using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
-        using MainloopTileShape = std::conditional_t; \
-        \
         using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder::CollectiveOp; \
         \
         using GemmKernel = cutlass::gemm::kernel::GemmUniversal; \
         // \
         //     using ElementAccumulator = typename GemmInfo::ElementAccumulator; \
@@ -614,6 +631,6 @@ using SafeBF16 = void;
         TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \
         cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);

-} // namespace cutlass_kernels
+} // namespace cutlass_kernels_oss
 } // namespace kernels
 } // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h
index f63df8944b..2b6b3a81cd 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h
@@ -23,15 +23,17 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-namespace cutlass_kernels
+namespace cutlass_kernels_oss
 {
-
+using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
+using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
 template 
-void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput inputs,
+void sm90_generic_mixed_moe_gemm_kernelLauncher(
+    tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput inputs,
     TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);

-} // namespace cutlass_kernels
+} // namespace cutlass_kernels_oss
 } // namespace kernels
 } // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
index a0ebfbde34..eac301fe82 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
@@ -65,7 +65,7 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-namespace cutlass_kernels
+namespace cutlass_kernels_oss
 {
 namespace tk = tensorrt_llm::common;
 namespace tkc = tensorrt_llm::cutlass_extensions;
@@ -244,6 +244,6 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(
-        reinterpret_cast(inputs.A), reinterpret_cast(inputs.B),
+    tensorrt_llm::kernels::cutlass_kernels_oss::sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(inputs.A),
+        reinterpret_cast(inputs.B),
         reinterpret_cast(inputs.biases), inputs.bias_is_broadcast, reinterpret_cast(inputs.C),
         inputs.total_tokens_including_expert, inputs.num_rows, inputs.n, inputs.k, inputs.num_experts, sm_count_,
         inputs.stream, inputs.occupancy);
@@ -242,16 +243,18 @@ static void dispatch(GroupedGemmInput)
 &&!isFp4)
     {
         // dispatch for quant op type
-        auto* launcher = kernels::cutlass_kernels::genericMoeGemmKernelLauncher::call;
+        auto* launcher
+            = tensorrt_llm::kernels::cutlass_kernels::genericMoeGemmKernelLauncher::call;
         if (!std::is_same_v && inputs.groupwise_quant_group_size > 0)
         {
-            launcher = inputs.zeros ? kernels::cutlass_kernels::genericMoeGemmKernelLauncher::call
-                                    : kernels::cutlass_kernels::genericMoeGemmKernelLauncher::call;
+            launcher = inputs.zeros
+                ? tensorrt_llm::kernels::cutlass_kernels::genericMoeGemmKernelLauncher::call
+                : tensorrt_llm::kernels::cutlass_kernels::genericMoeGemmKernelLauncher::call;
         }
         launcher(inputs, sm_count_);
     }
@@ -503,13 +506,14 @@ MoeGemmRunner::getAmpereConfigs(int sm
     auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
         weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);

-    if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89))
+    if (!tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation()
+        || (use_w4afp8 && sm != 89))
     {
         return {};
     }
     std::vector<CutlassGemmConfig> ampere_configs
-        = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
+        = tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
     return ampere_configs;
 }

@@ -528,30 +532,40 @@ MoeGemmRunner::getTmaWarpSpecializedCo
     int const enable_hopper = sm == 90 ? CutlassGemmConfig::HOPPER : CutlassGemmConfig::NONE;
     static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
     static constexpr auto fp4_only_flag
-        = (use_fp4 || use_wfp4afp4) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE;
+        = (use_fp4 || use_wfp4afp8) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE;
     auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(weight_only_flag
         | simt_only_flag | grouped_gemm_flag | enable_blackwell | enable_hopper | fp8_only_flag | fp4_only_flag);
     TLLM_CHECK_WITH_INFO(!(enable_blackwell && enable_hopper), "Blackwell and hopper flags are mutually exclusive");

-    if (sm >= 100 && sm < 120 && !kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
+    sm = use_wfp4afp8 && sm == 103 ? 100 : sm;
+    if (sm >= 100 && sm < 120
+        && !tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
     {
         TLLM_LOG_TRACE("Blackwell is not supported for this configuration, not selecting any TMA WS implementations");
         return {};
     }
-    if ((sm == 120 || sm == 121) && !kernels::cutlass_kernels::isValidSM120MOESpecialisation())
+    if ((sm == 120 || sm == 121)
+        && !tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation())
     {
         TLLM_LOG_TRACE(
             "Blackwell SM120 is not supported for this configuration, not selecting any TMA WS implementations");
         return {};
     }
-    if (enable_hopper && !kernels::cutlass_kernels::isValidHopperMOESpecialisation())
+    if (enable_hopper && !tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation())
     {
         TLLM_LOG_TRACE("Hopper is not supported for this configuration, not selecting any TMA WS implementations");
         return {};
     }

     std::vector<CutlassGemmConfig> tma_ws_configs
-        = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
+        = tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
+    if (sm == 103 && use_fp4)
+    {
+        // Explicitly select SM100 as well
+        auto sm100_configs
+            = tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(100, max_split_k, config_type_param);
+        std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs));
+    }
     return tma_ws_configs;
 }

@@ -566,9 +580,11 @@ bool MoeGemmRunner::isTmaWarpSpecializ
 template 
 bool MoeGemmRunner::supportsTmaWarpSpecialized() const
 {
-    return (sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation())
-        || (sm_ >= 100 && sm_ < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
-        || ((sm_ == 120 || sm_ == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation());
+    return (sm_ == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation())
+        || (sm_ >= 100 && sm_ < 120
+            && tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
+        || ((sm_ == 120 || sm_ == 121)
+            && tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation());
 }

 template 
@@ -658,15 +674,16 @@ void MoeGemmRunner::dispatchToArch(
         }
     }

-    if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation()
+    if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation()
         && !use_w4afp8)
     {
        // We allow both tma warp specialized and SM80 configurations to coexist because for some cases with small
        // numbers of tokens SM80 is faster. We check here to see which is selected
        if (inputs.gemm_config.sm_version >= 90)
        {
-            TLLM_CHECK_WITH_INFO(
-                (inputs.gemm_config.sm_version == sm_) || (inputs.gemm_config.sm_version == 100 && sm_ == 103),
+            // Check the major version of the SM matches
+            TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10,
                 "Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_);
             TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr,
                 "Input biases and hopper input disagree if bias is enabled");
@@ -679,11 +696,11 @@ void MoeGemmRunner::dispatchToArch(
             switch (hopper_inputs.fusion)
             {
             case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE:
-                return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized;
+                return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized;
             case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE:
-                return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized;
+                return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized;
             case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION:
             case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
             default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_inputs.fusion);
             }
@@ -707,19 +724,19 @@ void MoeGemmRunner::dispatchToArch(
             // EpilogueTag is ignored
             if (inputs.k % 512 == 0)
             {
-                sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
                     inputs, hopper_inputs, multi_processor_count_, nullptr);
             }
             else if (inputs.k % 256 == 0)
             {
-                sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
                     inputs, hopper_inputs, multi_processor_count_, nullptr);
             }
             else if (inputs.k % 128 == 0)
             {
-                sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
                     inputs, hopper_inputs, multi_processor_count_, nullptr);
             }
@@ -733,7 +750,8 @@ void MoeGemmRunner::dispatchToArch(
 #endif

     // Do Ampere case instead
-    if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation())
+    if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation())
     {
         TLLM_CHECK_WITH_INFO(!use_fp8, "No fallback FP8 implementation available");
         TLLM_CHECK_WITH_INFO(use_w4afp8 || !hopper_inputs.isValid(),
@@ -782,26 +800,19 @@ size_t MoeGemmRunner::calcMaxWorkspace
 {
     if constexpr (use_w4afp8)
     {
-        return calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(
+        return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(
             num_experts, multi_processor_count_);
     }
     if (!supportsTmaWarpSpecialized())
     {
         return 0;
     }
-    // #ifndef CUTLASS_ARCH_MMA_SM100F_SUPPORTED
-    //     static_assert(__CUDA_ARCH__ == 1000, "__CUDA_ARCH__");
-    //     static_assert(CUTLASS_ARCH_MMA_SM100_SUPPORTED, "CUTLASS_ARCH_MMA_SM100F_SUPPORTED");
-    //     static_assert(CUTLASS_ARCH_MMA_SM100_ENABLED, "CUTLASS_ARCH_MMA_SM100_ENABLED");
"CUTLASS_ARCH_MMA_SM100F_SUPPORTED"); - // static_assert(CUTLASS_ARCH_MMA_SM100F_ENABLED, "CUTLASS_ARCH_MMA_SM100F_ENABLED"); - // // #error "SM100F not supported!" - // #endif - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8) + if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() + && !use_w4afp8) { auto configs = getTmaWarpSpecializedConfigs(sm_); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; - if constexpr (use_wfp4afp4) + if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } @@ -818,8 +829,9 @@ size_t MoeGemmRunner::calcMaxWorkspace { \ try \ { \ - size_t size = calcMaxWorkspaceSizeTmaWarpSpecialized( \ - num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \ + size_t size \ + = cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecialized( \ + num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \ max_size = std::max(max_size, size); \ has_config = true; \ } \ diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index d9df31513f..48726f4b32 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -64,8 +64,9 @@ #include #include -namespace tensorrt_llm::kernels::cutlass_kernels +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; template = 100 && Arch::kMinComputeCapability < 120) { @@ -122,31 +129,36 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, "MXFPX is the only supported scaling type for WFP4AFP8"); - return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher; + return &tma_warp_specialized_generic_moe_gemm_kernelLauncher; } else { TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, "MXFPX is not supported for the selected weight combination"); - return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher; + return &tma_warp_specialized_generic_moe_gemm_kernelLauncher; } }; getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); } } -template +template constexpr bool are_tile_shapes_supported_sm100() { using namespace cute; - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // This is the epilogue shape. 
     // This is the epilogue shape. The MMA shape will be twice this for 2SM
     constexpr auto TileM = size<0>(CtaShape{});
     constexpr auto TileN = size<1>(CtaShape{});

+    if constexpr (Arch::kMinComputeCapability == 103)
+    {
+        return std::is_same_v && std::is_same_v && TileM == 128
+            && (TileN == 128 || TileN == 256);
+    }
+
     if constexpr (TileM != 64 && TileM != 128)
     {
         return false;
     }
@@ -224,7 +236,7 @@ constexpr bool are_tile_shapes_supported()
 {
     if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120)
     {
-        return are_tile_shapes_supported_sm100();
+        return are_tile_shapes_supported_sm100();
     }
     else if constexpr (Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121)
     {
@@ -347,12 +359,34 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
             TLLM_THROW("Unsupported SM90 configuration requested");
         }
     }
+#ifdef ENABLE_FP4
+    // Check this before SM100 because we fall back to SM100 if not NVFP4
+    else if (gemm_config.sm_version == 103
+        && std::is_same_v && std::is_same_v)
+    {
+        if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
+        {
+            switch (gemm_config.tile_config_sm100)
+            {
+                SHAPE_CASE(103, 128, 128, 128)
+                SHAPE_CASE(103, 128, 256, 128)
+
+                DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103
+            }
+        }
+        else
+        {
+            TLLM_THROW("Unsupported SM103 configuration requested");
+        }
+    }
+#endif
     else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120)
     {
         if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
         {
             switch (gemm_config.tile_config_sm100)
             {
+                SHAPE_CASE(100, 64, 32, 128)
                 SHAPE_CASE(100, 64, 64, 128)
                 SHAPE_CASE(100, 64, 128, 128)
                 SHAPE_CASE(100, 64, 256, 128)
@@ -363,10 +397,6 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
                 SHAPE_CASE(100, 128, 128, 128)
                 SHAPE_CASE(100, 128, 256, 128)

-                SHAPE_CASE(100, 256, 64, 128)
-                SHAPE_CASE(100, 256, 128, 128)
-                SHAPE_CASE(100, 256, 256, 128)
-
                 // SHAPE_CASE(100, 128, 128, 64)
                 // SHAPE_CASE(100, 128, 256, 64)
                 // SHAPE_CASE(100, 256, 256, 64)
@@ -409,4 +439,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecialized(int num_experts, cutlass_extension
     return count;
 }

-} // namespace tensorrt_llm::kernels::cutlass_kernels
+} // namespace tensorrt_llm::kernels::cutlass_kernels_oss
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h
index 9a9f2ebeb3..4c0ddebf6a 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h
@@ -57,9 +57,11 @@
 #include
 #include

-namespace tensorrt_llm::kernels::cutlass_kernels
+namespace tensorrt_llm::kernels::cutlass_kernels_oss
 {

+using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
+using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
 namespace tk = tensorrt_llm::common;
 namespace tkc = tensorrt_llm::cutlass_extensions;

@@ -236,4 +238,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_
     return count;
 }

-} // namespace tensorrt_llm::kernels::cutlass_kernels
+} // namespace tensorrt_llm::kernels::cutlass_kernels_oss
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py
index 7e55098cb6..bdb8af652d 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py
@@ -302,12 +302,12 @@ namespace tensorrt_llm
 {{
 namespace kernels
 {{
-namespace cutlass_kernels
+namespace cutlass_kernels_oss
 {{

 {instantiations}

-}} // namespace cutlass_kernels
+}} // namespace cutlass_kernels_oss
 }} // namespace kernels
 }} // namespace tensorrt_llm
 """
@@ -337,18 +337,16 @@ def write_file(launcher_inl_files, operations, output_file):
         f.write(content)


-from operator import mul, truediv
-
-
-def elementwise(x, y, f):
-    return tuple(f(a, b) for (a, b) in zip(x, y))
-
-
 def is_gemm_op_valid_sm100(op):
     # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future
-    tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv)
+    tile_m, tile_n, _ = op.cta_shape
     cga_m, cga_n, _ = op.cga_shape

+    if op.arch == 103:
+        return op.act_type == e2m1 and op.weight_type == e2m1 and tile_m == 128 and tile_n in [
+            128, 256
+        ]
+
     # Default shapes
     # This is epilogue tile size. For two CTA this is actually size 128/256 for the MMA
     if tile_m not in [64, 128]:
@@ -366,10 +364,7 @@ def is_gemm_op_valid_sm100(op):
     if (op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8)
             and (cga_m == 1 and cga_n == 1)):
         # todo: double check why this is disable in CUTLASS backend. @yuhan
-        if tile_m == 128 and tile_n == 8:
-            return False
-        else:
-            return True
+        return not (tile_m == 128 and tile_n % 16 != 0)

     # Default alignment requirements
     if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256:
@@ -628,7 +623,6 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
     operations = list()

     for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
-        cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)

         # Ignored
         mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
@@ -641,8 +635,8 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
         for otype in otypes:
             moe_gemm_operation = TrtLlm_GemmLauncher(
                 GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
-                quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages,
-                cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
+                quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
+                mainloop_schedule, epi_schedule, epi_fusion)

             operations.append(moe_gemm_operation)
     return operations
@@ -653,10 +647,9 @@ def generate_sm120_operations(is_arch_enabled):
     return operations


-def generate_sm100_grouped_gemm_operations(is_arch_enabled):
+def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch):
     if not is_arch_enabled:
         return []
-    arch = 100
     supported_dtypes = [
         DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3, e2m1,
         (DataType.e4m3, e2m1)
     ]
     quant_ops = [TrtLlm_QuantOp.none]
     epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
     cta_shapes_m = [64, 128]
-    cta_shapes_n = [8, 16, 32, 64, 128, 256]
+    cta_shapes_n = [8, 16, 32, 64, 128, 192, 256]
     cta_shapes_mn = product(cta_shapes_m, cta_shapes_n)

     warp_shape = [0, 0, 0]  # ignored except for naming
@@ -688,7 +681,6 @@
             weight_type = dtype

         cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype)
-        cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)

         # Ignored
         mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
@@ -709,7 +701,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
                     otype,
                     quant_op,
                     epi_tag,
-                    cga_tile_shape_mnk,
+                    cta_shape_mnk,
                     warp_shape,
                     stages,
                     cga_shape,
@@ -723,8 +715,13 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
     return operations


+def generate_sm103_operations(is_arch_enabled):
+    operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 103)
+    return operations
+
+
 def generate_sm100_operations(is_arch_enabled):
-    operations = generate_sm100_grouped_gemm_operations(is_arch_enabled)
+    operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100)
     return operations


@@ -804,6 +801,7 @@ if __name__ == "__main__":
         (GemmKind.Gemm, 90): [fpA_intB_inl],
         (GemmKind.Grouped, 90): [moe_gemm_inl],
         (GemmKind.Grouped, 100): [moe_gemm_inl],
+        (GemmKind.Grouped, 103): [moe_gemm_inl],
         (GemmKind.Grouped, 120): [moe_gemm_inl],
         (GemmKind.Grouped, 80): [sm80_moe_gemm_inl]
     }
@@ -815,7 +813,8 @@ if __name__ == "__main__":
     # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
     operations = []
     operations += generate_sm120_operations(has_arch(120) or has_arch(121))
-    operations += generate_sm100_operations(has_arch(100))
+    operations += generate_sm103_operations(has_arch(103))
+    operations += generate_sm100_operations(has_arch(100) or has_arch(103))
     operations += generate_sm90_operations(has_arch(90))
     operations += generate_sm80_operations(has_arch(80) or has_arch(89))

diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h
index bbeda00ade..a52f4b7aaf 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h
@@ -280,7 +280,6 @@ public:
 #else
     static constexpr bool use_fp8 = false;
     static constexpr bool use_w4afp8 = false;
-    static constexpr bool use_wfp4afp4 = false;
 #endif

 #if defined(ENABLE_FP4)
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index e43226a69d..da4a15284f 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -26,8 +26,8 @@ include(GoogleTest)
 include_directories(
   ${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include
   ${PROJECT_SOURCE_DIR}/include
-  ${3RDPARTY_DIR}/cutlass/include
-  ${3RDPARTY_DIR}/cutlass/tools/util/include
+  ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/include
+  ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/tools/util/include
   ${PROJECT_SOURCE_DIR}/tests/batch_manager
   ${PROJECT_SOURCE_DIR}/tests/utils)

diff --git a/cpp/tests/unit_tests/kernels/CMakeLists.txt b/cpp/tests/unit_tests/kernels/CMakeLists.txt
index ae3750597e..1add9d65e4 100644
--- a/cpp/tests/unit_tests/kernels/CMakeLists.txt
+++ b/cpp/tests/unit_tests/kernels/CMakeLists.txt
@@ -26,6 +26,9 @@ add_gtest(mixtureOfExpertsTest mixtureOfExpertsTest.cu)

 # If we are using oss cutlass, build an explicit internal test
 if(USING_OSS_CUTLASS_MOE_GEMM)
+  target_compile_definitions(mixtureOfExpertsTest
+                             PUBLIC USING_OSS_CUTLASS_MOE_GEMM)
+
   add_gtest(mixtureOfExpertsInternalTest mixtureOfExpertsTest.cu)
   remove_compile_definition(mixtureOfExpertsInternalTest
                             USING_OSS_CUTLASS_MOE_GEMM)
diff --git a/triton_backend/inflight_batcher_llm/CMakeLists.txt b/triton_backend/inflight_batcher_llm/CMakeLists.txt
index 0f26015922..5d3a11269e 100644
--- a/triton_backend/inflight_batcher_llm/CMakeLists.txt
+++ b/triton_backend/inflight_batcher_llm/CMakeLists.txt
@@ -217,7 +217,7 @@ target_include_directories(
     ${CUDA_INCLUDE_DIRS}
     ${CUDNN_ROOT_DIR}/include
     ${NCCL_INCLUDE_DIR}
-    ${3RDPARTY_DIR}/cutlass/include
+    ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/include
     ${MPI_INCLUDE_PATH}
     ${COMMON_HEADER_DIR})
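-- 
Editor's note (illustrative sketch, not part of the commit): the core heuristic change
in cutlass_heuristic.cpp replaces the nested cluster_m/cluster_n loops with a flat
(tile, cluster) candidate list, and getTmaWarpSpecializedConfigs additionally appends
the SM100 candidates when running FP4 on SM103, since SM103 devices can also execute
the SM100 kernels. Below is a minimal standalone C++ sketch of that selection pattern;
Config, Tile and Cluster are hypothetical stand-ins for the real TRT-LLM types, not
the library API.

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <utility>
#include <vector>

enum class Tile { Cta128x128x128B, Cta128x256x128B, Cta64x128x128B };
enum class Cluster { C1x1x1, C2x1x1 };

struct Config
{
    int sm_version;
    Tile tile;
    Cluster cluster;
};

// Stand-in for get_candidate_configs_sm100(config, sm): a flat (tile, cluster)
// list per SM, as in the patched heuristic.
std::vector<Config> candidates(int sm)
{
    std::vector<std::pair<Tile, Cluster>> tiles;
    if (sm == 103) // SM103 3xFP4 supports only CtaShapeM=128 tiles in this patch
        tiles = {{Tile::Cta128x128x128B, Cluster::C1x1x1}, {Tile::Cta128x256x128B, Cluster::C2x1x1}};
    else
        tiles = {{Tile::Cta128x128x128B, Cluster::C1x1x1}, {Tile::Cta64x128x128B, Cluster::C2x1x1}};

    std::vector<Config> out;
    for (auto [tile, cluster] : tiles)
        out.push_back(Config{sm, tile, cluster});
    return out;
}

// Mirrors getTmaWarpSpecializedConfigs: on SM103 with FP4 weights the SM100
// configs are appended too, so tuning can pick whichever kernel family wins.
std::vector<Config> tma_ws_configs(int sm, bool use_fp4)
{
    auto configs = candidates(sm);
    if (sm == 103 && use_fp4)
    {
        auto sm100 = candidates(100);
        std::copy(sm100.begin(), sm100.end(), std::back_inserter(configs));
    }
    return configs;
}

int main()
{
    std::printf("%zu candidate configs for SM103 FP4\n", tma_ws_configs(103, true).size());
    return 0;
}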