[ROCm][Quantization][3/N] Refactor quark_moe w4a4 w/ oracle (#41436)

Signed-off-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
Bowen Bao
2026-05-18 10:46:13 -07:00
committed by GitHub
parent 6859ca7615
commit a2c8fc6657
8 changed files with 224 additions and 158 deletions
@@ -0,0 +1,12 @@
model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
accuracy_threshold: 0.89
tolerance: 0.03
num_questions: 1319
num_fewshot: 5
server_args: >-
--max-model-len 4096
--tensor-parallel-size 2
--gpu-memory-utilization 0.35
--moe-backend aiter
env:
VLLM_ROCM_USE_AITER: "1"
@@ -1,8 +1,10 @@
model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
accuracy_threshold: 0.82
accuracy_threshold: 0.89
tolerance: 0.03
num_questions: 1319
num_fewshot: 5
server_args: >-
--max-model-len 4096
--tensor-parallel-size 2
--moe-backend emulation
--gpu-memory-utilization 0.35
+2 -1
View File
@@ -3,4 +3,5 @@ DeepSeek-R1-DP_MI325.yaml
DeepSeek-V3.2-TP_MI325.yaml
DeepSeek-V3.2-DP_MI325.yaml
Qwen3-30B-A3B-NVFP4.yaml
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
@@ -1,2 +1,3 @@
Qwen3.5-35B-A3B-DEP2.yaml
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
+86 -2
View File
@@ -1,6 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for MXFP4 MoE oracle backend selection on mi355x (GFX950).
These tests run on real hardware — no mocks. Skipped on non-GFX950 platforms.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4MoeBackend,
select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kMxfp4Dynamic,
)
from vllm.platforms import current_platform
ROCM_AVAILABLE = current_platform.is_rocm()
ROCM_GFX950 = False
ROCM_AITER_AVAILABLE = False
if ROCM_AVAILABLE:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms.rocm import on_gfx950
ROCM_GFX950 = on_gfx950()
ROCM_AITER_AVAILABLE = rocm_aiter_ops.is_fused_moe_enabled()
def test_mi355_moe():
print("TODO: add tests for Mi355 MoE quantization")
def _make_w4a4_moe_config(moe_backend: str = "auto") -> FusedMoEConfig:
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
return FusedMoEConfig(
num_experts=8,
experts_per_token=2,
hidden_dim=256,
intermediate_size_per_partition=256,
num_local_experts=8,
num_logical_experts=8,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU,
in_dtype=torch.bfloat16,
device="cuda",
routing_method=RoutingMethodType.Renormalize,
moe_backend=moe_backend,
)
@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
@pytest.mark.skipif(not ROCM_AITER_AVAILABLE, reason="Requires AITER enabled")
def test_w4a4_dispatches_to_aiter():
"""With AITER enabled + GFX950, W4A4 selects AITER_MXFP4_MXFP4."""
config = _make_w4a4_moe_config()
backend, experts_cls = select_mxfp4_moe_backend(
config, activation_key=kMxfp4Dynamic
)
assert backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4
assert experts_cls is not None
@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
@pytest.mark.skipif(
ROCM_AITER_AVAILABLE,
reason="Test requires AITER disabled (unset VLLM_ROCM_USE_AITER)",
)
def test_w4a4_raises_without_aiter_and_no_moe_backend():
"""Without AITER and no --moe-backend, raises NotImplementedError
with hint to use --moe-backend emulation."""
config = _make_w4a4_moe_config()
with pytest.raises(NotImplementedError, match="--moe-backend emulation"):
select_mxfp4_moe_backend(config, activation_key=kMxfp4Dynamic)
@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
def test_w4a4_dispatches_to_emulation_with_moe_backend():
"""With --moe-backend emulation, W4A4 selects EMULATION."""
config = _make_w4a4_moe_config(moe_backend="emulation")
backend, experts_cls = select_mxfp4_moe_backend(
config, activation_key=kMxfp4Dynamic
)
assert backend == Mxfp4MoeBackend.EMULATION
assert experts_cls is not None
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Dynamic,
kMxfp4Static,
)
@@ -377,6 +378,21 @@ class AiterExperts(mk.FusedMoEExpertsModular):
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def is_supported_config(
cls, moe_config, weight_key, activation_key, activation_format
):
is_supported, reason = super().is_supported_config(
cls, moe_config, weight_key, activation_key, activation_format
)
if not is_supported and not rocm_aiter_ops.is_fused_moe_enabled():
reason = (
f"{reason}. AITER MoE is not enabled — "
"set VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE=1 "
"to enable it"
)
return is_supported, reason
@staticmethod
def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled()
@@ -397,6 +413,7 @@ class AiterExperts(mk.FusedMoEExpertsModular):
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kMxfp4Static, None),
(kMxfp4Static, kMxfp4Dynamic),
]
if (weight_key, activation_key) not in SUPPORTED_W_A:
return False
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kMxfp4Dynamic,
kMxfp4Static,
kMxfp8Dynamic,
)
@@ -74,6 +75,7 @@ class Mxfp4MoeBackend(Enum):
# Keep the legacy name as an alias while the ROCm split backend rename settles.
AITER = "AITER_MXFP4_BF16"
AITER_MXFP4_FP8 = "AITER_MXFP4_FP8" # W4A8: triton kernel
AITER_MXFP4_MXFP4 = "AITER_MXFP4_MXFP4" # W4A4: CK kernel
# Triton
TRITON = "TRITON"
TRITON_UNFUSED = "TRITON_UNFUSED"
@@ -91,6 +93,7 @@ class Mxfp4MoeBackend(Enum):
AITER_BACKENDS = (
Mxfp4MoeBackend.AITER_MXFP4_BF16,
Mxfp4MoeBackend.AITER_MXFP4_FP8,
Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
)
@@ -195,6 +198,13 @@ def backend_to_kernel_cls(
return [AiterW4A8ExpertsMonolithic]
elif backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import (
AiterExperts,
)
return [AiterExperts]
elif backend == Mxfp4MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.experts.xpu_moe import XPUExpertsMXFp4
@@ -241,8 +251,10 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> list[Mxfp4MoeBackend]:
"aiter": [
Mxfp4MoeBackend.AITER_MXFP4_BF16,
Mxfp4MoeBackend.AITER_MXFP4_FP8,
Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
],
"aiter_mxfp4_fp8": [Mxfp4MoeBackend.AITER_MXFP4_FP8],
"aiter_mxfp4_mxfp4": [Mxfp4MoeBackend.AITER_MXFP4_MXFP4],
"xpu": [Mxfp4MoeBackend.XPU],
"cpu": [Mxfp4MoeBackend.CPU],
"emulation": [Mxfp4MoeBackend.EMULATION],
@@ -263,6 +275,7 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.AITER_MXFP4_BF16,
Mxfp4MoeBackend.AITER_MXFP4_FP8,
Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
@@ -272,7 +285,6 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU,
Mxfp4MoeBackend.EMULATION,
]
return _AVAILABLE_BACKENDS
@@ -308,6 +320,8 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
return kMxfp8Dynamic
if backend == Mxfp4MoeBackend.AITER_MXFP4_FP8:
return kFp8StaticTensorSym
if backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
return kMxfp4Dynamic
return None # BF16 activation
@@ -379,6 +393,7 @@ def _filter_by_activation(
b
for b in backends
if _backend_activation_key(b) == requested_activation_key
or b == Mxfp4MoeBackend.EMULATION
]
bf16 = [b for b in backends if _backend_activation_key(b) is None]
return bf16 if bf16 else backends
@@ -556,7 +571,12 @@ def select_mxfp4_moe_backend(
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No MXFP4 MoE backend supports the deployment configuration."
"No MXFP4 MoE backend supports the deployment configuration. "
f"weight_key=kMxfp4Static, activation_key={activation_key}. "
"Native backends require specific hardware. "
"Set `VLLM_LOGGING_LEVEL=DEBUG` to see detailed unsupported reasons. "
"To use the emulation backend for research/debugging, pass "
"--moe-backend emulation."
)
return Mxfp4MoeBackend.NONE, None
@@ -957,6 +977,49 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
from vllm._aiter_ops import rocm_aiter_ops
if w13_bias is not None:
w13_bias = w13_bias.data.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.data.to(torch.float32)
# e8m0_shuffle on weight scales (GFX950 swizzle layout)
from aiter.utility.fp4_utils import e8m0_shuffle
s0, s1, _ = w13_weight_scale.shape
w13_weight_scale.data = e8m0_shuffle(w13_weight_scale.view(s0 * s1, -1)).view(
s0, s1, -1
)
s0, s1, _ = w2_weight_scale.shape
w2_weight_scale.data = e8m0_shuffle(w2_weight_scale.view(s0 * s1, -1)).view(
s0, s1, -1
)
# View as native FP4 dtype
fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
if fp4_dtype is not None:
w13_weight.data = w13_weight.data.view(fp4_dtype)
w2_weight.data = w2_weight.data.view(fp4_dtype)
# Shuffle weights for AITER CK kernel
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
shuffled_w13.is_shuffled = True
shuffled_w2.is_shuffled = True
return (
shuffled_w13,
shuffled_w2,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_BF16:
from vllm._aiter_ops import rocm_aiter_ops
@@ -1541,6 +1604,17 @@ def make_mxfp4_moe_quant_config(
block_shape=None,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
@@ -55,6 +55,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
kMxfp4Dynamic,
kNvfp4Dynamic,
kNvfp4Static,
)
@@ -1042,6 +1043,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(
moe, activation_key=kFp8StaticTensorSym
)
elif self.ocp_mx_scheme == "w_mxfp4_a_mxfp4":
# W4A4: MXFP4 weights + MXFP4 activations
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(
moe, activation_key=kMxfp4Dynamic
)
# Validation for unsupported schemes
if any(
@@ -1061,45 +1067,19 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"Please open an issue."
)
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.model_type = getattr(
get_current_vllm_config().model_config.hf_config, "model_type", None
)
# TODO: Remove once all OCP MX schemes use the kernel abstraction
_AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8")
self.emulate = (
not current_platform.supports_mx()
or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES
) and (
self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
)
if self.emulate:
# We use the same code path between MXFP4/MXFP6 emulation.
# If no native backend available, use emulation.
if self.mxfp4_backend is Mxfp4MoeBackend.NONE:
self.mxfp4_backend = Mxfp4MoeBackend.EMULATION
# TODO: Remove `self.mxfp4_backend != Mxfp4MoeBackend.NONE` and make it so that
# all MXFP4 backends use the kernel abstraction.
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
# Log backend selection
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
logger.info_once(
f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}"
)
elif self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
)
logger.info_once(
f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}"
)
def maybe_roundup_sizes(
self,
@@ -1243,97 +1223,8 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: RoutedExperts):
# For MXFP4 schemes with native backend, use oracle
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
self._setup_kernel(layer)
return
if self.static_input_scales and self.input_dtype == "fp8":
# firstly, process activations if fp8 static input
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
_, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fn),
torch.empty_like(
layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w13_input_scale,
)
_, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fn),
torch.empty_like(
layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w2_input_scale,
)
# Reset the parameter
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# TODO(bowenbao): gradually migrate to oracles.
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
from aiter.utility.fp4_utils import e8m0_shuffle
# Pre-shuffle weight scales
s0, s1, _ = layer.w13_weight_scale.shape
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
s0, s1, _ = layer.w2_weight_scale.shape
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
if self.fp4_dtype is not None:
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight.view(self.fp4_dtype),
requires_grad=layer.w13_weight.requires_grad,
)
layer.w2_weight = torch.nn.Parameter(
layer.w2_weight.view(self.fp4_dtype),
requires_grad=layer.w2_weight.requires_grad,
)
# Pre-shuffle weight
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
# Build quant config for AITER path
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
torch.accelerator.empty_cache()
def process_weights_after_loading(self, layer):
self._setup_kernel(layer)
def _setup_kernel(self, layer: RoutedExperts):
"""Setup kernel using oracle functions for MXFP4 schemes (W4A16, W4A8)."""
@@ -1375,6 +1266,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
replace_parameter(layer, "w13_bias", w13_bias)
replace_parameter(layer, "w2_bias", w2_bias)
if self.mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
torch.accelerator.empty_cache()
# Build quant config and kernel
@@ -1464,38 +1359,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts: SharedExperts | None,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
# For oracle-based kernels (W4A16, W4A8) or emulation kernel
if self.moe_kernel is not None:
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts=shared_experts,
shared_experts_input=shared_experts_input,
)
# AITER path
# TODO: Refactor this to use modular MOE kernel as well.
from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import (
rocm_aiter_fused_experts,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
moe_config=layer.moe_config,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
def apply_monolithic(