[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths (#35568)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Blake Ledden
2026-05-15 13:59:00 -04:00
committed by GitHub
parent f45c210885
commit 06d020bb6e
8 changed files with 19 additions and 14 deletions
@@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};
@@ -72,7 +72,7 @@ struct cutlass_3x_gemm_sm120_custom {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule, void>::CollectiveOp;
using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};
@@ -15,11 +15,11 @@ SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but its simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
if arch == 89 or arch // 10 == 12:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
+2 -2
View File
@@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
major_capability == 12,
"Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than "
"Marlin W4A16 on other devices).");
}
+3 -3
View File
@@ -15,11 +15,11 @@ SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but its simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
if arch == 89 or arch // 10 == 12:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
+5 -1
View File
@@ -746,7 +746,8 @@ def marlin_moe_generate_valid_test_cases():
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
and not current_platform.is_device_capability(89)
and not current_platform.is_device_capability_family(120)
):
continue
@@ -897,6 +898,7 @@ class MarlinMoEWeightData:
marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.usefixtures("default_vllm_config")
def test_fused_marlin_moe(
a_type: ScalarType,
b_type: ScalarType,
@@ -1009,6 +1011,7 @@ def test_fused_marlin_moe(
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.usefixtures("default_vllm_config")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
set_random_seed(0)
@@ -1081,6 +1084,7 @@ def test_fused_marlin_moe_with_bias(m):
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.usefixtures("default_vllm_config")
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
@@ -381,7 +381,8 @@ def marlin_generate_valid_test_cases():
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
and not current_platform.is_device_capability(89)
and not current_platform.is_device_capability_family(120)
):
continue
args = sub_case + (size_m, size_n, size_k) + case[4:]
@@ -480,9 +480,9 @@ def get_marlin_input_dtype(prefix: str | None = None):
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
) and not current_platform.is_device_capability_family(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"Marlin W4A8-FP8 only support SM89 or SM12x device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."