fix kernel select code to recognize sm103/sm100f

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>

Author: Xiwen Yu
Date:   2025-07-02 19:02:55 +08:00
commit 1b846046dd
parent 5c09dc8304

8 changed files with 39 additions and 13 deletions


@@ -160,13 +160,19 @@ function(setup_cuda_architectures)
       ${CMAKE_CUDA_ARCHITECTURES_ORIG}
       PARENT_SCOPE)
-  set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 100 120)
+  set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 120)
   foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
     if(NOT ${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
       add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")
       message(STATUS "Excluding SM ${CUDA_ARCH}")
     endif()
   endforeach()
+  # deal with SM100/f
+  if(NOT "100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
+     AND NOT "100f" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
+    add_definitions("-DEXCLUDE_SM_100")
+    message(STATUS "Excluding SM 100(f)")
+  endif()
   # -a suffix supported from Hopper (90)
   set(MIN_ARCHITECTURE_HAS_ACCEL 90)
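
Reviewer note: a minimal standalone sketch (not part of this commit) of how a translation unit typically reacts to the EXCLUDE_SM_100 definition added above; the main() wrapper and printf strings are illustrative only. The point is that a single EXCLUDE_SM_100 macro now covers both the "100" and "100f" architecture spellings.

#include <cstdio>

int main()
{
#ifndef EXCLUDE_SM_100
    std::printf("SM100/SM100f kernels are compiled into this build.\n");
#else
    std::printf("SM100/SM100f kernels are excluded from this build.\n");
#endif
    return 0;
}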


@@ -311,6 +311,16 @@ inline int getSMVersion()
     return sm;
 }
 
+inline int getSMFamily()
+{
+    int sm = getSMVersion();
+    if (sm == 100 || sm == 103)
+    {
+        return 100;
+    }
+    return sm;
+}
+
 inline int getDevice()
 {
     int deviceID{0};
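
Reviewer note: a minimal sketch of the intent behind getSMFamily(), using local stand-ins rather than the real tensorrt_llm::common helpers (an assumption for illustration). Call sites that only need to know "Blackwell datacenter family" can branch on the family value instead of enumerating SM100 and SM103 separately.

#include <cstdio>

// Stand-ins for the real helpers, which query the active CUDA device.
int getSMVersion() { return 103; } // pretend we are running on an SM103 part
int getSMFamily()
{
    int sm = getSMVersion();
    return (sm == 100 || sm == 103) ? 100 : sm;
}

int main()
{
    if (getSMFamily() == 100)
    {
        std::printf("Using SM100-family kernels on SM%d.\n", getSMVersion());
    }
    return 0;
}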


@@ -330,7 +330,7 @@ size_t CutlassFp4GemmRunner<T, fp4GemmType>::dispatchToArch(T* D, void const* A,
 {
     if constexpr (fp4GemmType == FP4GemmType::W4A8_MXFP4_MXFP8)
     {
-        if (mSm == 100)
+        if (mSm == 100 || mSm == 103)
         {
             return dispatchMXFP8xMXFP4GemmCTAShapeSm100<T>(D, A, B, input_sf, weight_sf, global_sf, m, n, k,
                 batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
@@ -343,7 +343,7 @@ size_t CutlassFp4GemmRunner<T, fp4GemmType>::dispatchToArch(T* D, void const* A,
     }
     else if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4)
     {
-        if (mSm == 100)
+        if (mSm == 100 || mSm == 103)
         {
             return dispatchNVFP4xNVFP4GemmCTAShapeSm100<T>(D, A, B, input_sf, weight_sf, global_sf, m, n, k,
                 batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
@@ -384,7 +384,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
     std::vector<CutlassGemmConfig> candidateConfigs;
-    if (mSm == 100)
+    if (mSm == 100 || mSm == 103)
     {
         std::vector<tkc::CutlassTileConfigSM100> tilesSm100 = {
             tkc::CutlassTileConfigSM100::CtaShape128x64x128B,


@@ -665,8 +665,9 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
     // numbers of tokens SM80 is faster. We check here to see which is selected
     if (inputs.gemm_config.sm_version >= 90)
     {
-        TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version == sm_, "Using SM %d configuration for SM %d device",
-            inputs.gemm_config.sm_version, sm_);
+        TLLM_CHECK_WITH_INFO(
+            (inputs.gemm_config.sm_version == sm_) || (inputs.gemm_config.sm_version == 100 && sm_ == 103),
+            "Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_);
         TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr,
             "Input biases and hopper input disagree if bias is enabled");
         TLLM_CHECK_WITH_INFO(
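
Reviewer note: as a minimal sketch (configCompatible is an illustrative name, not a function in the codebase), the relaxed assertion above boils down to this predicate, which now also accepts an SM100-tuned GEMM config when running on an SM103 device.

#include <cassert>

// Illustrative helper mirroring the condition in the TLLM_CHECK_WITH_INFO above.
bool configCompatible(int config_sm, int device_sm)
{
    return (config_sm == device_sm) || (config_sm == 100 && device_sm == 103);
}

int main()
{
    assert(configCompatible(100, 100));  // exact match, as before
    assert(configCompatible(100, 103));  // SM100 config reused on SM103
    assert(!configCompatible(90, 103));  // Hopper config still rejected on Blackwell
    return 0;
}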
@@ -788,6 +789,14 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace
     {
         return 0;
     }
+    // #ifndef CUTLASS_ARCH_MMA_SM100F_SUPPORTED
+    // static_assert(__CUDA_ARCH__ == 1000, "__CUDA_ARCH__");
+    // static_assert(CUTLASS_ARCH_MMA_SM100_SUPPORTED, "CUTLASS_ARCH_MMA_SM100F_SUPPORTED");
+    // static_assert(CUTLASS_ARCH_MMA_SM100_ENABLED, "CUTLASS_ARCH_MMA_SM100_ENABLED");
+    // static_assert(CUTLASS_ARCH_MMA_SM100F_SUPPORTED, "CUTLASS_ARCH_MMA_SM100F_SUPPORTED");
+    // static_assert(CUTLASS_ARCH_MMA_SM100F_ENABLED, "CUTLASS_ARCH_MMA_SM100F_ENABLED");
+    // // #error "SM100F not supported!"
+    // #endif
     if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>() && !use_w4afp8)
     {
         auto configs = getTmaWarpSpecializedConfigs(sm_);


@@ -809,7 +809,7 @@ if __name__ == "__main__":
     }
 
     def has_arch(sm):
-        return f"{sm}" in arches or f"{sm}-real" in arches
+        return f"{sm}" in arches or f"{sm}-real" in arches or f"{sm}f-real" in arches or f"{sm}f" in arches
 
     # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
     # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.


@@ -364,7 +364,8 @@ void* HostAccessibleDeviceAllocator::allocate(size_t memorySize)
         TLLM_CHECK_WITH_INFO(
             mAllowManagedFallback, "HostAccessibleDeviceAllocator is not supported on the current system.");
         TLLM_CUDA_CHECK(cudaMallocManaged(&devPtr, memorySize));
-        TLLM_CUDA_CHECK(cudaMemAdvise(devPtr, memorySize, cudaMemAdviseSetPreferredLocation, currentDevId));
+        TLLM_CUDA_CHECK(cudaMemAdvise(
+            devPtr, memorySize, cudaMemAdviseSetPreferredLocation, {cudaMemLocationTypeDevice, currentDevId}));
         hostPtr = devPtr;
     }
     recordAllocation(devPtr, memorySize, hostPtr, memDesc);
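
Reviewer note: a minimal standalone sketch of the adjusted call, assuming a CUDA toolkit whose cudaMemAdvise overload takes a cudaMemLocation, as the hunk above implies; on older toolkits the location-based form is exposed as cudaMemAdvise_v2 instead. The buffer size and device ordinal are arbitrary.

#include <cuda_runtime.h>
#include <cstdio>

int main()
{
    int dev = 0;
    cudaSetDevice(dev);

    void* p = nullptr;
    size_t bytes = 1 << 20;
    cudaMallocManaged(&p, bytes);

    // The preferred location is described as a (type, id) pair instead of a bare device ordinal.
    cudaMemLocation loc{};
    loc.type = cudaMemLocationTypeDevice;
    loc.id = dev;
    cudaError_t err = cudaMemAdvise(p, bytes, cudaMemAdviseSetPreferredLocation, loc);
    std::printf("cudaMemAdvise: %s\n", cudaGetErrorString(err));

    cudaFree(p);
    return 0;
}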


@@ -38,8 +38,8 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
     int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
     int64_t const routing_method_type, bool const do_finalize, MoeRunnerType& moe_runner, int64_t const moeConfigIndex)
 {
-    auto const sm = tensorrt_llm::common::getSMVersion();
-    TORCH_CHECK(sm == 100, "Only SM100 is supported by FP4 block scale MOE");
+    auto const sm = tensorrt_llm::common::getSMFamily();
+    TORCH_CHECK(sm == 100, "Only SM100f is supported by FP4 block scale MOE");
     TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64,
         "tile_tokens_dim must be 8, 16, 32, 64");
     if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3)


@@ -568,7 +568,7 @@ def quant_dequant_per_tensor_fp8(a):
 @pytest.mark.skipif(
-    getSMVersion() != 100,
+    getSMVersion() < 100 or getSMVersion() >= 110,
     reason="The kernel only supports Blackwell. Current SM is %d." %
     getSMVersion(),
 )
@@ -702,7 +702,7 @@ class TestMoeFP8:
     @pytest.mark.skipif(
-        getSMVersion() != 100,
+        getSMVersion() < 100 or getSMVersion() >= 110,
         reason="The kernel only supports Blackwell. Current SM is %d." %
         getSMVersion(),
     )
@@ -1061,7 +1061,7 @@ class TestMoeFp4:
     @pytest.mark.skipif(
-        getSMVersion() != 100,
+        getSMVersion() < 100 or getSMVersion() >= 110,
         reason="The kernel only supports Blackwell. Current SM is %d." %
         getSMVersion(),
     )