mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths (#35568)
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 {
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
|
||||
};
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ struct cutlass_3x_gemm_sm120_custom {
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule, void>::CollectiveOp;
|
||||
|
||||
using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
|
||||
};
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
# only SM89 and SM120 fully support
|
||||
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
|
||||
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
|
||||
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
|
||||
# SM90 and SM100 can use this PTX, but it’s simulated
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
if arch == 89 or arch // 10 == 12:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
|
||||
@@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
"FP8 only support Ada Lovelace or newer GPUs.");
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
major_capability * 10 + minor_capability == 120,
|
||||
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
|
||||
major_capability == 12,
|
||||
"Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than "
|
||||
"Marlin W4A16 on other devices).");
|
||||
}
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
# only SM89 and SM120 fully support
|
||||
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
|
||||
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
|
||||
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
|
||||
# SM90 and SM100 can use this PTX, but it’s simulated
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
if arch == 89 or arch // 10 == 12:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
|
||||
@@ -746,7 +746,8 @@ def marlin_moe_generate_valid_test_cases():
|
||||
for sub_case in inner_combinations:
|
||||
if (
|
||||
sub_case[0] == scalar_types.float8_e4m3fn
|
||||
and current_platform.get_device_capability() not in [89, 120]
|
||||
and not current_platform.is_device_capability(89)
|
||||
and not current_platform.is_device_capability_family(120)
|
||||
):
|
||||
continue
|
||||
|
||||
@@ -897,6 +898,7 @@ class MarlinMoEWeightData:
|
||||
marlin_moe_generate_valid_test_cases(),
|
||||
)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
@pytest.mark.usefixtures("default_vllm_config")
|
||||
def test_fused_marlin_moe(
|
||||
a_type: ScalarType,
|
||||
b_type: ScalarType,
|
||||
@@ -1009,6 +1011,7 @@ def test_fused_marlin_moe(
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
@pytest.mark.usefixtures("default_vllm_config")
|
||||
@pytest.mark.parametrize("m", [1, 256])
|
||||
def test_fused_marlin_moe_with_bias(m):
|
||||
set_random_seed(0)
|
||||
@@ -1081,6 +1084,7 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
@pytest.mark.usefixtures("default_vllm_config")
|
||||
@pytest.mark.parametrize("m", [1, 64, 256])
|
||||
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
|
||||
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
|
||||
|
||||
@@ -381,7 +381,8 @@ def marlin_generate_valid_test_cases():
|
||||
for sub_case in inner_combinations:
|
||||
if (
|
||||
sub_case[0] == scalar_types.float8_e4m3fn
|
||||
and current_platform.get_device_capability() not in [89, 120]
|
||||
and not current_platform.is_device_capability(89)
|
||||
and not current_platform.is_device_capability_family(120)
|
||||
):
|
||||
continue
|
||||
args = sub_case + (size_m, size_n, size_k) + case[4:]
|
||||
|
||||
@@ -480,9 +480,9 @@ def get_marlin_input_dtype(prefix: str | None = None):
|
||||
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
|
||||
if not current_platform.is_device_capability(
|
||||
89
|
||||
) and not current_platform.is_device_capability(120):
|
||||
) and not current_platform.is_device_capability_family(120):
|
||||
raise ValueError(
|
||||
"Marlin W4A8-FP8 only support SM89 or SM120 device "
|
||||
"Marlin W4A8-FP8 only support SM89 or SM12x device "
|
||||
"(It is slower than Marlin W4A16 on other devices). "
|
||||
"You can consider using W4A8-INT8 instead"
|
||||
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
|
||||
|
||||
Reference in New Issue
Block a user