From 06d020bb6ef70c9549a4e4c29493ebf048c06ee0 Mon Sep 17 00:00:00 2001 From: Blake Ledden Date: Fri, 15 May 2026 13:59:00 -0400 Subject: [PATCH] [Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths (#35568) Signed-off-by: Blake Ledden Co-authored-by: Claude Opus 4.6 Co-authored-by: Pavani Majety --- .../quantization/w8a8/cutlass/c3x/scaled_mm.cuh | 2 +- .../w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh | 2 +- csrc/moe/marlin_moe_wna16/generate_kernels.py | 6 +++--- csrc/moe/marlin_moe_wna16/ops.cu | 4 ++-- csrc/quantization/marlin/generate_kernels.py | 6 +++--- tests/kernels/moe/test_moe.py | 6 +++++- tests/kernels/quantization/test_marlin_gemm.py | 3 ++- .../layers/quantization/utils/marlin_utils.py | 4 ++-- 8 files changed, 19 insertions(+), 14 deletions(-) diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh index e98433bed25..952931103c6 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh @@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 { sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue, void>>; }; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh index 226e4f7a6bd..7a7229c95ba 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh @@ -72,7 +72,7 @@ struct cutlass_3x_gemm_sm120_custom { sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule, void>::CollectiveOp; - using GemmKernel = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue, void>>; }; diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index bca697cae92..6ddda1d51db 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -15,11 +15,11 @@ SUPPORT_SM80 = False for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) - # only SM89 and SM120 fully support - # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10) + # fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. - if arch in [89, 120]: + if arch == 89 or arch // 10 == 12: SUPPORT_FP8 = True if arch >= 80: SUPPORT_SM80 = True diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index cf97f95a8fc..82cba2978b1 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, "FP8 only support Ada Lovelace or newer GPUs."); TORCH_CHECK( major_capability * 10 + minor_capability == 89 || - major_capability * 10 + minor_capability == 120, - "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + major_capability == 12, + "Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than " "Marlin W4A16 on other devices)."); } diff --git a/csrc/quantization/marlin/generate_kernels.py b/csrc/quantization/marlin/generate_kernels.py index 19a42de1f87..7b316037ec6 100644 --- a/csrc/quantization/marlin/generate_kernels.py +++ b/csrc/quantization/marlin/generate_kernels.py @@ -15,11 +15,11 @@ SUPPORT_SM80 = False for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) - # only SM89 and SM120 fully support - # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10) + # fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. - if arch in [89, 120]: + if arch == 89 or arch // 10 == 12: SUPPORT_FP8 = True if arch >= 80: SUPPORT_SM80 = True diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index c4efe0a0479..3bdcc447e5e 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -746,7 +746,8 @@ def marlin_moe_generate_valid_test_cases(): for sub_case in inner_combinations: if ( sub_case[0] == scalar_types.float8_e4m3fn - and current_platform.get_device_capability() not in [89, 120] + and not current_platform.is_device_capability(89) + and not current_platform.is_device_capability_family(120) ): continue @@ -897,6 +898,7 @@ class MarlinMoEWeightData: marlin_moe_generate_valid_test_cases(), ) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.usefixtures("default_vllm_config") def test_fused_marlin_moe( a_type: ScalarType, b_type: ScalarType, @@ -1009,6 +1011,7 @@ def test_fused_marlin_moe( @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.usefixtures("default_vllm_config") @pytest.mark.parametrize("m", [1, 256]) def test_fused_marlin_moe_with_bias(m): set_random_seed(0) @@ -1081,6 +1084,7 @@ def test_fused_marlin_moe_with_bias(m): @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.usefixtures("default_vllm_config") @pytest.mark.parametrize("m", [1, 64, 256]) @pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)]) @pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)]) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index f918212f763..8b35fab81ef 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -381,7 +381,8 @@ def marlin_generate_valid_test_cases(): for sub_case in inner_combinations: if ( sub_case[0] == scalar_types.float8_e4m3fn - and current_platform.get_device_capability() not in [89, 120] + and not current_platform.is_device_capability(89) + and not current_platform.is_device_capability_family(120) ): continue args = sub_case + (size_m, size_n, size_k) + case[4:] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 4600cb36918..892d600b72f 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -480,9 +480,9 @@ def get_marlin_input_dtype(prefix: str | None = None): elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8": if not current_platform.is_device_capability( 89 - ) and not current_platform.is_device_capability(120): + ) and not current_platform.is_device_capability_family(120): raise ValueError( - "Marlin W4A8-FP8 only support SM89 or SM120 device " + "Marlin W4A8-FP8 only support SM89 or SM12x device " "(It is slower than Marlin W4A16 on other devices). " "You can consider using W4A8-INT8 instead" "(set VLLM_MARLIN_INPUT_DTYPE=int8)."