feat: Add support for SM103 3xFP4 tile shapes

Signed-off-by: Daniel Stokes <dastokes@nvidia.com>
Daniel Stokes 2025-07-08 12:37:34 +12:00, committed by Xiwen Yu
parent 3a94d80839
commit 469a38d0d8
27 changed files with 320 additions and 222 deletions

.gitmodules vendored
View File

@ -26,3 +26,6 @@
[submodule "3rdparty/cppzmq"]
path = 3rdparty/cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "3rdparty/dynamic-kernel-generator"]
path = 3rdparty/dynamic-kernel-generator
url = ssh://git@gitlab-master.nvidia.com:12051/dlarch-fastkernels/dynamic-kernel-generator.git

3rdparty/cutlass vendored

@ -1 +1 @@
Subproject commit dc4817921edda44a549197ff3a9dcf5df0636e7b
Subproject commit a1aaf2300a8fc3a8106a05436e1a2abad0930443

3rdparty/dynamic-kernel-generator vendored Submodule

@ -0,0 +1 @@
Subproject commit 34bfe3557372d1d2cebe3c90448b03756c6a16eb

View File

@ -215,8 +215,8 @@ include_directories(
${CUDAToolkit_INCLUDE_DIRS}/cccl
${CUDNN_ROOT_DIR}/include
$<TARGET_PROPERTY:TensorRT::NvInfer,INTERFACE_INCLUDE_DIRECTORIES>
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/include
${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/tools/util/include
${3RDPARTY_DIR}/NVTX/include
${3RDPARTY_DIR}/json/include)
if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)

View File

@ -150,6 +150,9 @@ function(setup_cuda_architectures)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.7")
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 100 120)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 103)
endif()
endif()
# CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without
@ -160,7 +163,14 @@ function(setup_cuda_architectures)
${CMAKE_CUDA_ARCHITECTURES_ORIG}
PARENT_SCOPE)
set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 120)
set(ARCHITECTURES_WITH_KERNELS
80
86
89
90
100
103
120)
foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
if(NOT ${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")

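The EXCLUDE_SM_<arch> definitions registered above let sources drop whole architectures at preprocessing time. A minimal consumer sketch (hypothetical translation unit, not a file in this change):

#ifndef EXCLUDE_SM_103
// 103 was in CMAKE_CUDA_ARCHITECTURES: SM103 instantiations are compiled in.
constexpr bool kSm103KernelsEnabled = true;
#else
// 103 was excluded: the SM103 code paths disappear entirely.
constexpr bool kSm103KernelsEnabled = false;
#endif
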
View File

@ -133,13 +133,10 @@ enum class CutlassTileConfigSM100
CtaShape128x256x128B,
CtaShape128x128x256B,
CtaShape128x256x256B,
// M=256
CtaShape256x64x128B,
CtaShape256x128x128B,
CtaShape256x256x128B,
};
using CutlassTileConfigSM103 = CutlassTileConfigSM100;
enum class CutlassTileConfigSM120
{
// Signals that we should run heuristics to choose a config
@ -461,14 +458,15 @@ struct CutlassGemmConfig
}
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape, int sm_version = 100)
: tile_config_sm100(tile_config_sm100)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
, sm_version(100)
, sm_version(sm_version)
, is_tma_warp_specialized(true)
{
assert(sm_version >= 100 && sm_version < 120 && "Expected SM 10x version");
}
CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule,

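A minimal usage sketch of the extended constructor (values illustrative): CutlassTileConfigSM103 is the alias introduced above, so SM103 reuses the SM100 tile enum and is distinguished only by sm_version.

CutlassGemmConfig cfg{CutlassTileConfigSM103::CtaShape128x128x128B,
    MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
    ClusterShape::ClusterShape_1x1x1, /*sm_version=*/103};
// cfg.sm_version == 103 and cfg.is_tma_warp_specialized == true; the new
// assert rejects any sm_version outside the SM 10x range.
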
View File

@ -28,7 +28,7 @@ if(NOT Python3_EXECUTABLE)
endif()
execute_process(
WORKING_DIRECTORY ${3RDPARTY_DIR}/cutlass/python/
WORKING_DIRECTORY ${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/python/
COMMAND ${Python3_EXECUTABLE} setup_library.py develop --user
RESULT_VARIABLE _CUTLASS_LIBRARY_SUCCESS)
@ -72,10 +72,14 @@ function(process_target target_name enable_hopper enable_blackwell)
if(${enable_blackwell}
AND ("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG))
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
))
if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
# Both 100 and 103 support these kernels
if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
# No kernels should be parsed unless blackwell is specified. This is a
# build-time improvement
target_compile_definitions(${target_name}
@ -83,6 +87,13 @@ function(process_target target_name enable_hopper enable_blackwell)
target_compile_definitions(${target_name}
PUBLIC COMPILE_BLACKWELL_TMA_GROUPED_GEMMS)
endif()
# SM103 only kernels
if("103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target_name}
PUBLIC COMPILE_BLACKWELL_SM103_TMA_GEMMS)
target_compile_definitions(
${target_name} PUBLIC COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS)
endif()
if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target_name}
@ -113,6 +124,8 @@ function(add_instantiations library base_dir)
list(LENGTH INSTANTIATIONS_GENERATED_${ARCH} n)
if(${n} GREATER 0)
set(TARGET_NAME "_${library}_instantiations_${ARCH}")
message(
STATUS "Adding target ${TARGET_NAME} with instantiations for ${ARCH}")
add_library(${TARGET_NAME} OBJECT ${INSTANTIATIONS_GENERATED_${ARCH}})
target_link_libraries(${library} PRIVATE ${TARGET_NAME})
set_cuda_architectures(${TARGET_NAME} ${BUILD_ARCHS})
@ -128,6 +141,7 @@ function(add_instantiations library base_dir)
glob_src_create_target(80 "80;86")
glob_src_create_target(90 90)
glob_src_create_target(100 100f)
glob_src_create_target(103 103)
glob_src_create_target(120 120f)
endfunction()
@ -231,7 +245,7 @@ if(USING_OSS_CUTLASS_MOE_GEMM)
add_cuda_architectures(_moe_gemm_launcher 89)
add_library(_moe_gemm_fp4 OBJECT ${MOE_GEMM_SRC_CU_FP4})
set_cuda_architectures(_moe_gemm_fp4 100f 120f)
set_cuda_architectures(_moe_gemm_fp4 100f 103 120f)
process_target(_moe_gemm_fp4 false true)
add_library(_moe_gemm_fp8 OBJECT ${MOE_GEMM_SRC_CU_FP8})

View File

@ -367,7 +367,8 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
return candidate_configs;
}
std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::CandidateConfigTypeParam const config)
std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
CutlassGemmConfig::CandidateConfigTypeParam const config, int sm)
{
#ifdef FAST_BUILD
// Fast build disables all configs except this one for SM100
@ -377,72 +378,78 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
std::vector<CutlassGemmConfig> candidate_configs;
if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
if (config & CutlassGemmConfig::FP4_ONLY)
{
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
return candidate_configs;
if (sm == 103)
{
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, sm});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1, sm});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, sm});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM103::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1, sm});
return candidate_configs;
}
else
{
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
// TODO These need a specific epilogue sub tile (128, 64), not EpilogueTileAuto, otherwise they crash
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
return candidate_configs;
}
}
for (int cluster_m = 1; cluster_m <= 2; cluster_m++)
std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1},
};
if (config & CutlassGemmConfig::FP8_ONLY)
{
bool Is2SM = cluster_m == 2;
for (int cluster_n = 1; cluster_n <= 2; cluster_n++)
{
std::vector base = {// M=128
CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B};
tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1});
// TODO(sklevtsov): re-enable when handled by the MoE GEMM dispatch
// tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 });
}
if (Is2SM)
{
if (cluster_n == 1)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
}
std::vector twosm = {// M=256
CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B};
std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
}
else
{
if (cluster_n == 1)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
if ((config & CutlassGemmConfig::FP8_ONLY) != 0)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
}
}
std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x64x128B};
std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
}
constexpr std::array cluster_shapes
= {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
for (auto tile : base)
{
CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
candidate_configs.push_back(config);
}
}
for (auto [tile, cluster] : tile_configs)
{
CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
candidate_configs.push_back(config);
}
return candidate_configs;
}
@ -523,7 +530,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
}
if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL))
{
return get_candidate_configs_sm100(config_type_param);
return get_candidate_configs_sm100(config_type_param, sm);
}
if (sm >= 120 && (config_type_param & CutlassGemmConfig::BLACKWELL))
{

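With the sm parameter threaded through, an SM103 FP4 grouped GEMM now gets its own candidate list. A hedged sketch of the call site (flag composition illustrative; names are from this file):

auto params = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
    CutlassGemmConfig::GROUPED_GEMM | CutlassGemmConfig::BLACKWELL | CutlassGemmConfig::FP4_ONLY);
// sm = 103 hits the CutlassTileConfigSM103 branch above: 128x128 and 128x256
// CTA tiles, each paired with 1x1x1 and 2x1x1 clusters, tagged sm_version 103.
auto configs = get_candidate_configs(/*sm=*/103, /*max_split_k=*/1, params);
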
View File

@ -50,6 +50,7 @@ namespace kernels
{
namespace cutlass_kernels
{
using namespace cute;
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
@ -452,9 +453,9 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
static_assert(!cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value
|| cutlass::platform::is_same<ScaleZeroType, half>::value,
"ScaleZeroType must be half for activation=fp8");
sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr,
workspace_bytes, gemm_config, stream, occupancy);
cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType,
OutputType, QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else
{

View File

@ -30,7 +30,7 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
@ -268,6 +268,6 @@ void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B,
}
}
} // namespace cutlass_kernels
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -22,7 +22,7 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
@ -34,6 +34,6 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr);
} // namespace cutlass_kernels
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -51,8 +51,9 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{
using namespace tensorrt_llm::kernels::cutlass_kernels;
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
@ -295,6 +296,6 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -280,15 +280,14 @@ public:
#else
static constexpr bool use_fp8 = false;
static constexpr bool use_w4afp8 = false;
static constexpr bool use_wfp4afp4 = false;
#endif
#if defined(ENABLE_FP4)
static constexpr bool use_fp4 = std::is_same_v<T, __nv_fp4_e2m1>;
static constexpr bool use_wfp4afp4 = std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>;
static constexpr bool use_wfp4afp8 = std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>;
#else
static constexpr bool use_fp4 = false;
static constexpr bool use_wfp4afp4 = false;
static constexpr bool use_wfp4afp8 = false;
#endif
void moeGemmBiasAct(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
namespace tensorrt_llm::kernels::cutlass_kernels
namespace tensorrt_llm::kernels::cutlass_kernels_oss
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>

View File

@ -27,7 +27,7 @@
#include "cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels::cutlass_kernels
namespace tensorrt_llm::kernels::cutlass_kernels_oss
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
@ -93,4 +93,4 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
} // namespace tensorrt_llm::kernels::cutlass_kernels_oss

View File

@ -19,9 +19,9 @@
#include "../../include/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm::kernels::cutlass_kernels
namespace tensorrt_llm::kernels::cutlass_kernels_oss
{
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
// Keep in sync with the signature generated by generate_kernels.py
template <typename Arch, typename T, typename WeightType, typename OutputType, typename EpilogueTag,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool IsMXFPX,
@ -29,4 +29,4 @@ template <typename Arch, typename T, typename WeightType, typename OutputType, t
void tma_warp_specialized_generic_moe_gemm_kernelLauncher(TmaWarpSpecializedGroupedGemmInput hopper_input,
int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
} // namespace tensorrt_llm::kernels::cutlass_kernels
} // namespace tensorrt_llm::kernels::cutlass_kernels_oss

View File

@ -66,8 +66,9 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{
using namespace tensorrt_llm::kernels::cutlass_kernels;
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
// Constructs an object with specific arguments only if flag is true
@ -131,6 +132,12 @@ void tma_warp_specialized_generic_moe_gemm_kernelLauncher(TmaWarpSpecializedGrou
TLLM_THROW("Please recompile with support for blackwell by passing 100-real as an arch to build_wheel.py.");
}
#endif
#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS
else if constexpr (ArchTag::kMinComputeCapability == 103)
{
TLLM_THROW("Please recompile with support for blackwell by passing 103-real as an arch to build_wheel.py.");
}
#endif
#ifndef COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS
else if constexpr (ArchTag::kMinComputeCapability >= 120)
{
@ -195,10 +202,26 @@ using SafeBF16 = void;
using T = DataType_; \
using WeightType = WeightType_; \
using OutputType = OutputType_; \
using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \
using TileShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \
constexpr static bool IsMXFPX = MXFPX_; \
using Arch = ArchTag; \
constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \
constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \
constexpr static bool IsSM103 = ArchTag::kMinComputeCapability == 103; \
constexpr static bool IsWFP4AFP8 \
= cutlass::platform::is_same<WeightType, SafeFP4>::value && cutlass::platform::is_same<T, SafeFP8>::value; \
constexpr static bool IsFP4 = cutlass::platform::is_same<T, SafeFP4>::value; \
static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \
\
constexpr static bool IsFP8 = cutlass::platform::is_same<T, SafeFP8>::value; \
\
constexpr static bool IsSM103FP4 = IsSM103 && IsFP4; \
static_assert(IsSM103 == IsSM103FP4, "SM103 only implemented for fp4"); \
\
constexpr static bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \
using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \
using MmaTileShape = cute::Shape<cute::Int<CTA_M_*(Is2SM ? 2 : 1)>, cute::Int<CTA_N_>, \
cute::Int<CTA_K_*(IsSM103FP4 ? 3 : 1)>>; \
using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \
\
if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \
&& ArchTag::kMinComputeCapability < 100) \
@ -217,24 +240,15 @@ using SafeBF16 = void;
TLLM_THROW( \
"Please recompile with support for blackwell by passing 120-real as an arch to build_wheel.py."); \
} \
else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<ArchTag, TileShape, ClusterShape, \
T>) \
else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<ArchTag, MmaTileShape, \
ClusterShape, T>) \
{ \
using namespace cute; \
/* Helper class for defining all the cutlass types \
// template <typename ArchTag, typename T, typename WeightType, typename OutputType, typename EpilogueTag, \
// typename TileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
// typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
// struct TmaWarpSpecializedGroupedGemmInfo \
{ */ \
using Arch = ArchTag; \
constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \
constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \
constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same<WeightType, SafeFP4>::value \
&& cutlass::platform::is_same<T, SafeFP8>::value; \
constexpr static bool IsFP4 = cutlass::platform::is_same<T, SafeFP4>::value; \
static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \
\
constexpr static bool IsFP8 = cutlass::platform::is_same<T, SafeFP8>::value; \
\
/* TODO Update once mixed input support is added */ \
static_assert(cutlass::platform::is_same<T, WeightType>::value || IsWFP4AFP8, \
@ -332,15 +346,16 @@ using SafeBF16 = void;
constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \
using EpilogueScheduleSM100 = std::conditional_t<Is2SM, cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm, \
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm>; \
using EpilogueScheduleSM103 \
= std::conditional_t<Is2SM, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm, \
cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm>; \
using EpilogueScheduleSM10x \
= std::conditional_t<IsSM103FP4, EpilogueScheduleSM103, EpilogueScheduleSM100>; \
\
using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \
using EpilogueScheduleBW = std::conditional_t<IsSM120, EpilogueScheduleSM120, EpilogueScheduleSM100>; \
using EpilogueScheduleBW = std::conditional_t<IsSM120, EpilogueScheduleSM120, EpilogueScheduleSM10x>; \
using EpilogueSchedule = std::conditional_t<IsBlackwell, EpilogueScheduleBW, EpilogueScheduleSM90>; \
\
using EpilogueTileShapeSm90 = TileShape; \
using AtomClusterDiv = std::conditional_t<Is2SM, _2, _1>; \
using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<AtomClusterDiv, _1, _1>{})); \
using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
using EpilogueTileShape = std::conditional_t<IsBlackwell, EpilogueTileShapeSm100, EpilogueTileShapeSm90>; \
using EpilogueElementC = std::conditional_t<IsSM120, ElementCSafe, ElementC>; \
using EpilogueTensorOp = std::conditional_t<IsBlackwell && IsBlockScaled, \
cutlass::arch::OpClassBlockScaledTensorOp, cutlass::arch::OpClassTensorOp>; \
@ -350,7 +365,7 @@ using SafeBF16 = void;
/* Epilogue For Default Finalize */ \
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \
Arch, EpilogueTensorOp, /**/ \
EpilogueTileShape, ClusterShape, /**/ \
MmaTileShape, ClusterShape, /**/ \
EpilogueSubTile, /**/ \
ElementAccumulator, ElementAccumulator, /**/ \
EpilogueElementC, LayoutC*, AlignmentC, /**/ \
@ -360,7 +375,7 @@ using SafeBF16 = void;
/* Epilogue For Fused Finalize */ \
using CollectiveEpilogueFinalize = \
typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< /**/ \
Arch, EpilogueTileShape, /**/ \
Arch, MmaTileShape, /**/ \
ElementCSafe, StrideC*, /**/ \
ElementFinalOutput, \
TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, /**/ \
@ -389,13 +404,19 @@ using SafeBF16 = void;
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100>; \
\
/* TRT-LLM uses vector size 16 for block scaled */ \
using KernelScheduleSM103 = std::conditional_t<Is2SM, \
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaled3xOmmaVs16Sm103, \
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaled3xOmmaVs16Sm103>; \
\
using KernelScheduleSM100 = std::conditional_t<Is2SM, \
std::conditional_t<IsBlockScaled, KernelSchedule2SmSm100BlockScaled, \
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100>, \
std::conditional_t<IsBlockScaled, KernelSchedule1SmSm100BlockScaled, \
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100>>; \
using KernelScheduleSM10x = std::conditional_t<IsSM103FP4, KernelScheduleSM103, KernelScheduleSM100>; \
\
using KernelScheduleSM120 = cutlass::gemm::collective::KernelScheduleAuto; \
using KernelScheduleBW = std::conditional_t<IsSM120, KernelScheduleSM120, KernelScheduleSM100>; \
using KernelScheduleBW = std::conditional_t<IsSM120, KernelScheduleSM120, KernelScheduleSM10x>; \
\
using KernelSchedule = std::conditional_t<IsBlackwell, KernelScheduleBW, KernelScheduleSM90>; \
\
@ -405,16 +426,12 @@ using SafeBF16 = void;
using MainloopElementA = std::conditional_t<IsBlackwell && IsBlockScaled, ElementABlockScaled, ElementA>; \
using MainloopElementB = std::conditional_t<IsBlackwell && IsBlockScaled, ElementBBlockScaled, ElementB>; \
\
using MainloopTileShapeSm90 = TileShape; \
using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
using MainloopTileShape = std::conditional_t<IsBlackwell, MainloopTileShapeSm100, MainloopTileShapeSm90>; \
\
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder</**/ \
Arch, TensorOp, /**/ \
MainloopElementB, LayoutB*, AlignmentB, /* A & B swapped here */ \
MainloopElementA, LayoutA*, AlignmentA, /**/ \
ElementAccumulator, /**/ \
MainloopTileShape, ClusterShape, /**/ \
MmaTileShape, ClusterShape, /**/ \
StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \
\
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<TmaWarpSpecializedGroupedGemmInput::ProblemShape, \
@ -426,7 +443,7 @@ using SafeBF16 = void;
// using namespace cute; \
// using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;<ArchTag, T, WeightType, OutputType, \
EpilogueTag, \
// TileShape, \
// MmaTileShape, \
// ClusterShape, BIAS, FUSION>; \
// \
// using ElementAccumulator = typename GemmInfo::ElementAccumulator; \
@ -614,6 +631,6 @@ using SafeBF16 = void;
TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \
cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
} // namespace cutlass_kernels
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

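The renamed MmaTileShape is the heart of the SM103 path: the K extent is tripled for SM103 FP4 (the 3x block-scaled OMMA schedule) and M is doubled for 2SM clusters, where the old code instead divided a cluster-level TileShape per CTA. A standalone sketch of the same arithmetic (assuming cute from the bundled CUTLASS):

#include <cute/tensor.hpp>
// CTA tile 128x128x128B, 2x1x1 cluster, fp4 x fp4 on SM103:
constexpr bool Is2SM = true;       // CGA_M % 2 == 0
constexpr bool IsSM103FP4 = true;  // SM103 with nvfp4 activations and weights
using MmaTileShape = cute::Shape<cute::Int<128 * (Is2SM ? 2 : 1)>, // M = 256
    cute::Int<128>,                                                // N = 128
    cute::Int<128 * (IsSM103FP4 ? 3 : 1)>>;                        // K = 384
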
View File

@ -23,15 +23,17 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{
using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
cutlass::WeightOnlyQuantOp QuantOp>
void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
void sm90_generic_mixed_moe_gemm_kernelLauncher(
tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
} // namespace cutlass_kernels
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -65,7 +65,7 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
@ -244,6 +244,6 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
return;
}
} // namespace cutlass_kernels
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -201,9 +201,10 @@ struct genericMoeGemmKernelLauncher
// kernel.. (only support
// fp16 or bf16)
{
sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(
reinterpret_cast<ElementType const*>(inputs.A), reinterpret_cast<CutlassWeightType const*>(inputs.B),
tensorrt_llm::kernels::cutlass_kernels_oss::sm80_generic_fused_moe_gemm_kernelLauncher<ElementType,
CutlassWeightType, ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK, Stages,
EpilogueTag>(reinterpret_cast<ElementType const*>(inputs.A),
reinterpret_cast<CutlassWeightType const*>(inputs.B),
reinterpret_cast<ElementType const*>(inputs.biases), inputs.bias_is_broadcast,
reinterpret_cast<ElementType*>(inputs.C), inputs.total_tokens_including_expert, inputs.num_rows,
inputs.n, inputs.k, inputs.num_experts, sm_count_, inputs.stream, inputs.occupancy);
@ -242,16 +243,18 @@ static void dispatch(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputT
&& (!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>) && !isFp4)
{
// dispatch for quant op type
auto* launcher = kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch,
cutlass::WeightOnlyQuantOp::UNDEFINED, EpilogueTag, ThreadblockShape, WarpShape, Stages>::call;
auto* launcher
= tensorrt_llm::kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch,
cutlass::WeightOnlyQuantOp::UNDEFINED, EpilogueTag, ThreadblockShape, WarpShape, Stages>::call;
if (!std::is_same_v<WeightType, T> && inputs.groupwise_quant_group_size > 0)
{
launcher = inputs.zeros ? kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType,
GemmOutputType, Arch, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, EpilogueTag,
ThreadblockShape, WarpShape, Stages>::call
: kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType,
GemmOutputType, Arch, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY,
EpilogueTag, ThreadblockShape, WarpShape, Stages>::call;
launcher = inputs.zeros
? tensorrt_llm::kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType,
Arch, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, EpilogueTag, ThreadblockShape,
WarpShape, Stages>::call
: tensorrt_llm::kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType,
Arch, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, EpilogueTag, ThreadblockShape, WarpShape,
Stages>::call;
}
launcher(inputs, sm_count_);
}
@ -503,13 +506,14 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>() || (use_w4afp8 && sm != 89))
if (!tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>()
|| (use_w4afp8 && sm != 89))
{
return {};
}
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
= tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
return ampere_configs;
}
@ -528,30 +532,40 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedCo
int const enable_hopper = sm == 90 ? CutlassGemmConfig::HOPPER : CutlassGemmConfig::NONE;
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
static constexpr auto fp4_only_flag
= (use_fp4 || use_wfp4afp4) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE;
= (use_fp4 || use_wfp4afp8) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE;
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(weight_only_flag | simt_only_flag
| grouped_gemm_flag | enable_blackwell | enable_hopper | fp8_only_flag | fp4_only_flag);
TLLM_CHECK_WITH_INFO(!(enable_blackwell && enable_hopper), "Blackwell and hopper flags are mutually exclusive");
if (sm >= 100 && sm < 120 && !kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
sm = use_wfp4afp8 && sm == 103 ? 100 : sm;
if (sm >= 100 && sm < 120
&& !tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
{
TLLM_LOG_TRACE("Blackwell is not supported for this configuration, not selecting any TMA WS implementations");
return {};
}
if ((sm == 120 || sm == 121) && !kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>())
if ((sm == 120 || sm == 121)
&& !tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>())
{
TLLM_LOG_TRACE(
"Blackwell SM120 is not supported for this configuration, not selecting any TMA WS implementations");
return {};
}
if (enable_hopper && !kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
if (enable_hopper && !tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
{
TLLM_LOG_TRACE("Hopper is not supported for this configuration, not selecting any TMA WS implementations");
return {};
}
std::vector<cutlass_extensions::CutlassGemmConfig> tma_ws_configs
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
= tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
if (sm == 103 && use_fp4)
{
// Explicitly select SM100 as well
auto sm100_configs
= tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(100, max_split_k, config_type_param);
std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs));
}
return tma_ws_configs;
}
@ -566,9 +580,11 @@ bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isTmaWarpSpecializ
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsTmaWarpSpecialized() const
{
return (sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|| (sm_ >= 100 && sm_ < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
|| ((sm_ == 120 || sm_ == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>());
return (sm_ == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|| (sm_ >= 100 && sm_ < 120
&& tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
|| ((sm_ == 120 || sm_ == 121)
&& tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>());
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
@ -658,15 +674,16 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
}
}
if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType, EpilogueTag>()
if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType,
EpilogueTag>()
&& !use_w4afp8)
{
// We allow both tma warp specialized and SM80 configurations to coexist because for some cases with small
// numbers of tokens SM80 is faster. We check here to see which is selected
if (inputs.gemm_config.sm_version >= 90)
{
TLLM_CHECK_WITH_INFO(
(inputs.gemm_config.sm_version == sm_) || (inputs.gemm_config.sm_version == 100 && sm_ == 103),
// Check the major version of the SM matches
TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10,
"Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_);
TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr,
"Input biases and hopper input disagree if bias is enabled");
@ -679,11 +696,11 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
switch (hopper_inputs.fusion)
{
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE:
return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType, EpilogueTag,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE>;
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE>;
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE:
return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType, EpilogueTag,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>;
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>;
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION:
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_inputs.fusion);
@ -707,19 +724,19 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
// EpilogueTag is ignored
if (inputs.k % 512 == 0)
{
sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 4>(
inputs, hopper_inputs, multi_processor_count_, nullptr);
}
else if (inputs.k % 256 == 0)
{
sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 2>(
inputs, hopper_inputs, multi_processor_count_, nullptr);
}
else if (inputs.k % 128 == 0)
{
sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 1>(
inputs, hopper_inputs, multi_processor_count_, nullptr);
}
@ -733,7 +750,8 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
#endif
// Do Ampere case instead
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType,
EpilogueTag>())
{
TLLM_CHECK_WITH_INFO(!use_fp8, "No fallback FP8 implementation available");
TLLM_CHECK_WITH_INFO(use_w4afp8 || !hopper_inputs.isValid(),
@ -782,26 +800,19 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace
{
if constexpr (use_w4afp8)
{
return calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput<T, WeightType, OutputType>(
return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput<T, WeightType, OutputType>(
num_experts, multi_processor_count_);
}
if (!supportsTmaWarpSpecialized())
{
return 0;
}
// #ifndef CUTLASS_ARCH_MMA_SM100F_SUPPORTED
// static_assert(__CUDA_ARCH__ == 1000, "__CUDA_ARCH__");
// static_assert(CUTLASS_ARCH_MMA_SM100_SUPPORTED, "CUTLASS_ARCH_MMA_SM100F_SUPPORTED");
// static_assert(CUTLASS_ARCH_MMA_SM100_ENABLED, "CUTLASS_ARCH_MMA_SM100_ENABLED");
// static_assert(CUTLASS_ARCH_MMA_SM100F_SUPPORTED, "CUTLASS_ARCH_MMA_SM100F_SUPPORTED");
// static_assert(CUTLASS_ARCH_MMA_SM100F_ENABLED, "CUTLASS_ARCH_MMA_SM100F_ENABLED");
// // #error "SM100F not supported!"
// #endif
if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>() && !use_w4afp8)
if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>()
&& !use_w4afp8)
{
auto configs = getTmaWarpSpecializedConfigs(sm_);
auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
if constexpr (use_wfp4afp4)
if constexpr (use_wfp4afp8)
{
fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
}
@ -818,8 +829,9 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace
{ \
try \
{ \
size_t size = calcMaxWorkspaceSizeTmaWarpSpecialized<T, WeightType, OutputType, FUSION>( \
num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \
size_t size \
= cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecialized<T, WeightType, OutputType, FUSION>( \
num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \
max_size = std::max(max_size, size); \
has_config = true; \
} \

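Restating the selection policy above: wfp4afp8 has no SM103 kernels, so the SM is remapped to 100 before candidate generation, while pure FP4 on SM103 profiles the union of SM103 and SM100 candidates, and dispatchToArch now only checks the SM major version. A condensed control-flow sketch (not a drop-in function):

int effective_sm = (use_wfp4afp8 && sm == 103) ? 100 : sm; // no wfp4afp8 SM103 kernels
auto configs = get_candidate_configs(effective_sm, max_split_k, config_type_param);
if (effective_sm == 103 && use_fp4) // fp4: also consider the SM100 shapes
{
    auto sm100 = get_candidate_configs(100, max_split_k, config_type_param);
    configs.insert(configs.end(), sm100.begin(), sm100.end());
}
// dispatchToArch then accepts any config with sm_version / 10 == sm_ / 10.
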
View File

@ -64,8 +64,9 @@
#include <math.h>
#include <sstream>
namespace tensorrt_llm::kernels::cutlass_kernels
namespace tensorrt_llm::kernels::cutlass_kernels_oss
{
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
template <typename Arch, typename T, typename WeightType, typename OutputType, typename EpilogueTag,
@ -101,6 +102,12 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
}
#endif
#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS
else if constexpr (Arch::kMinComputeCapability == 103)
{
TLLM_THROW("Please recompile with support for blackwell by passing 103-real as an arch to build_wheel.py.");
}
#endif
#ifndef COMPILE_BLACKWELL_TMA_GROUPED_GEMMS
else if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120)
{
@ -122,31 +129,36 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn
TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type
== TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX,
"MXFPX is the only supported scaling type for WFP4AFP8");
return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher<Arch, T,
WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, true, false>;
return &tma_warp_specialized_generic_moe_gemm_kernelLauncher<Arch, T, WeightType, OutputType,
EpilogueTag, FUSION, TileShape, ClusterShape, true, false>;
}
else
{
TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type
!= TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX,
"MXFPX is not supported for the selected weight combination");
return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher<Arch, T,
WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, false, false>;
return &tma_warp_specialized_generic_moe_gemm_kernelLauncher<Arch, T, WeightType, OutputType,
EpilogueTag, FUSION, TileShape, ClusterShape, false, false>;
}
};
getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
}
template <typename ClusterTileShape, typename ClusterShape, typename DataType, typename WeightType>
template <typename Arch, typename CtaShape, typename ClusterShape, typename DataType, typename WeightType>
constexpr bool are_tile_shapes_supported_sm100()
{
using namespace cute;
using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{}));
// This is the epilogue shape. The MMA shape will be twice this for 2SM
constexpr auto TileM = size<0>(CtaShape{});
constexpr auto TileN = size<1>(CtaShape{});
if constexpr (Arch::kMinComputeCapability == 103)
{
return std::is_same_v<DataType, __nv_fp4_e2m1> && std::is_same_v<WeightType, __nv_fp4_e2m1> && TileM == 128
&& (TileN == 128 || TileN == 256);
}
if constexpr (TileM != 64 && TileM != 128)
{
return false;
@ -224,7 +236,7 @@ constexpr bool are_tile_shapes_supported()
{
if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120)
{
return are_tile_shapes_supported_sm100<CTAShape, ClusterShape, DataType, WeightType>();
return are_tile_shapes_supported_sm100<Arch, CTAShape, ClusterShape, DataType, WeightType>();
}
else if constexpr (Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121)
{
@ -347,12 +359,34 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
TLLM_THROW("Unsupported SM90 configuration requested");
}
}
#ifdef ENABLE_FP4
// Check this before SM100 because we fall back to SM100 if not NVFP4
else if (gemm_config.sm_version == 103
&& std::is_same_v<T, __nv_fp4_e2m1> && std::is_same_v<WeightType, __nv_fp4_e2m1>)
{
if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType, EpilogueTag, FUSION>())
{
switch (gemm_config.tile_config_sm100)
{
SHAPE_CASE(103, 128, 128, 128)
SHAPE_CASE(103, 128, 256, 128)
DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103
}
}
else
{
TLLM_THROW("Unsupported SM103 configuration requested");
}
}
#endif
else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120)
{
if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType, EpilogueTag, FUSION>())
{
switch (gemm_config.tile_config_sm100)
{
SHAPE_CASE(100, 64, 32, 128)
SHAPE_CASE(100, 64, 64, 128)
SHAPE_CASE(100, 64, 128, 128)
SHAPE_CASE(100, 64, 256, 128)
@ -363,10 +397,6 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
SHAPE_CASE(100, 128, 128, 128)
SHAPE_CASE(100, 128, 256, 128)
SHAPE_CASE(100, 256, 64, 128)
SHAPE_CASE(100, 256, 128, 128)
SHAPE_CASE(100, 256, 256, 128)
// SHAPE_CASE(100, 128, 128, 64)
// SHAPE_CASE(100, 128, 256, 64)
// SHAPE_CASE(100, 256, 256, 64)
@ -409,4 +439,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecialized(int num_experts, cutlass_extension
return count;
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
} // namespace tensorrt_llm::kernels::cutlass_kernels_oss

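The compile-time filter above is what keeps the SM103 instantiation space down to the two shipped shapes. A hedged sketch of the predicate (the cutlass::arch::Sm103 tag name is an assumption based on kMinComputeCapability == 103; __nv_fp4_e2m1 requires ENABLE_FP4):

// SM103 accepts only fp4 x fp4 with a 128-row CTA tile and N in {128, 256}:
static_assert(are_tile_shapes_supported_sm100<cutlass::arch::Sm103,
    cute::Shape<cute::Int<128>, cute::Int<256>, cute::Int<128>>, // CTA tile
    cute::Shape<cute::Int<2>, cute::Int<1>, cute::Int<1>>,       // cluster
    __nv_fp4_e2m1, __nv_fp4_e2m1>());
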
View File

@ -57,9 +57,11 @@
#include <math.h>
#include <sstream>
namespace tensorrt_llm::kernels::cutlass_kernels
namespace tensorrt_llm::kernels::cutlass_kernels_oss
{
using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
@ -236,4 +238,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_
return count;
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
} // namespace tensorrt_llm::kernels::cutlass_kernels_oss

View File

@ -302,12 +302,12 @@ namespace tensorrt_llm
{{
namespace kernels
{{
namespace cutlass_kernels
namespace cutlass_kernels_oss
{{
{instantiations}
}} // namespace cutlass_kernels
}} // namespace cutlass_kernels_oss
}} // namespace kernels
}} // namespace tensorrt_llm
"""
@ -337,18 +337,16 @@ def write_file(launcher_inl_files, operations, output_file):
f.write(content)
from operator import mul, truediv
def elementwise(x, y, f):
return tuple(f(a, b) for (a, b) in zip(x, y))
def is_gemm_op_valid_sm100(op):
# TODO These are much more restricted than theory dictates, investigate if more can be enabled in future
tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv)
tile_m, tile_n, _ = op.cta_shape
cga_m, cga_n, _ = op.cga_shape
if op.arch == 103:
return op.act_type == e2m1 and op.weight_type == e2m1 and tile_m == 128 and tile_n in [
128, 256
]
# Default shapes
# This is epilogue tile size. For two CTA this is actually size 128/256 for the MMA
if tile_m not in [64, 128]:
@ -366,10 +364,7 @@ def is_gemm_op_valid_sm100(op):
if (op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8)
and (cga_m == 1 and cga_n == 1)):
# todo: double check why this is disabled in CUTLASS backend. @yuhan
if tile_m == 128 and tile_n == 8:
return False
else:
return True
return not (tile_m == 128 and tile_n % 16 != 0)
# Default alignment requirements
if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256:
@ -628,7 +623,6 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
operations = list()
for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
@ -641,8 +635,8 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
for otype in otypes:
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages,
cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
mainloop_schedule, epi_schedule, epi_fusion)
operations.append(moe_gemm_operation)
return operations
@ -653,10 +647,9 @@ def generate_sm120_operations(is_arch_enabled):
return operations
def generate_sm100_grouped_gemm_operations(is_arch_enabled):
def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch):
if not is_arch_enabled:
return []
arch = 100
supported_dtypes = [
DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3, e2m1,
(DataType.e4m3, e2m1)
@ -664,7 +657,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
quant_ops = [TrtLlm_QuantOp.none]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
cta_shapes_m = [64, 128]
cta_shapes_n = [8, 16, 32, 64, 128, 256]
cta_shapes_n = [8, 16, 32, 64, 128, 192, 256]
cta_shapes_mn = product(cta_shapes_m, cta_shapes_n)
warp_shape = [0, 0, 0] # ignored except for naming
@ -688,7 +681,6 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
weight_type = dtype
cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype)
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
@ -709,7 +701,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
otype,
quant_op,
epi_tag,
cga_tile_shape_mnk,
cta_shape_mnk,
warp_shape,
stages,
cga_shape,
@ -723,8 +715,13 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
return operations
def generate_sm103_operations(is_arch_enabled):
operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 103)
return operations
def generate_sm100_operations(is_arch_enabled):
operations = generate_sm100_grouped_gemm_operations(is_arch_enabled)
operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100)
return operations
@ -804,6 +801,7 @@ if __name__ == "__main__":
(GemmKind.Gemm, 90): [fpA_intB_inl],
(GemmKind.Grouped, 90): [moe_gemm_inl],
(GemmKind.Grouped, 100): [moe_gemm_inl],
(GemmKind.Grouped, 103): [moe_gemm_inl],
(GemmKind.Grouped, 120): [moe_gemm_inl],
(GemmKind.Grouped, 80): [sm80_moe_gemm_inl]
}
@ -815,7 +813,8 @@ if __name__ == "__main__":
# Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
operations = []
operations += generate_sm120_operations(has_arch(120) or has_arch(121))
operations += generate_sm100_operations(has_arch(100))
operations += generate_sm103_operations(has_arch(103))
operations += generate_sm100_operations(has_arch(100) or has_arch(103))
operations += generate_sm90_operations(has_arch(90))
operations += generate_sm80_operations(has_arch(80) or has_arch(89))

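This is the generator-side counterpart of the MmaTileShape change: TrtLlm_GemmLauncher now receives the raw CTA tile (cta_shape_mnk) rather than the cluster-tile product (cga_tile_shape_mnk), leaving the 2SM doubling and SM103 3x K to the C++ macro. The relationship the old code computed, as a sketch in C++:

#include <array>
// Old behaviour: multiply the CTA tile by the cluster shape elementwise.
constexpr std::array<int, 3> cta{128, 128, 128}, cga{2, 1, 1};
constexpr std::array<int, 3> cga_tile{cta[0] * cga[0], cta[1] * cga[1], cta[2] * cga[2]}; // {256, 128, 128}
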
View File

@ -280,7 +280,6 @@ public:
#else
static constexpr bool use_fp8 = false;
static constexpr bool use_w4afp8 = false;
static constexpr bool use_wfp4afp4 = false;
#endif
#if defined(ENABLE_FP4)

View File

@ -26,8 +26,8 @@ include(GoogleTest)
include_directories(
${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include
${PROJECT_SOURCE_DIR}/include
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/include
${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/tools/util/include
${PROJECT_SOURCE_DIR}/tests/batch_manager
${PROJECT_SOURCE_DIR}/tests/utils)

View File

@ -26,6 +26,9 @@ add_gtest(mixtureOfExpertsTest mixtureOfExpertsTest.cu)
# If we are using oss cutlass, build an explicit internal test
if(USING_OSS_CUTLASS_MOE_GEMM)
target_compile_definitions(mixtureOfExpertsTest
PUBLIC USING_OSS_CUTLASS_MOE_GEMM)
add_gtest(mixtureOfExpertsInternalTest mixtureOfExpertsTest.cu)
remove_compile_definition(mixtureOfExpertsInternalTest
USING_OSS_CUTLASS_MOE_GEMM)

View File

@ -217,7 +217,7 @@ target_include_directories(
${CUDA_INCLUDE_DIRS}
${CUDNN_ROOT_DIR}/include
${NCCL_INCLUDE_DIR}
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/dynamic-kernel-generator/cutlass/include
${MPI_INCLUDE_PATH}
${COMMON_HEADER_DIR})