From a2c8fc66573664395f491a94da1882fdf92e034b Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Mon, 18 May 2026 10:46:13 -0700 Subject: [PATCH] [ROCm][Quantization][3/N] Refactor quark_moe w4a4 w/ oracle (#41436) Signed-off-by: Bowen Bao --- .../Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml | 12 ++ ...aml => Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml} | 4 +- tests/evals/gsm8k/configs/models-mi3xx.txt | 3 +- .../gsm8k/configs/models-qwen35-mi355.txt | 3 +- tests/quantization/test_gfx950_moe.py | 88 ++++++++- .../fused_moe/experts/rocm_aiter_moe.py | 17 ++ .../layers/fused_moe/oracle/mxfp4.py | 78 +++++++- .../layers/quantization/quark/quark_moe.py | 177 +++--------------- 8 files changed, 224 insertions(+), 158 deletions(-) create mode 100644 tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml rename tests/evals/gsm8k/configs/{Qwen3.5-35B-A3B-MXFP4-TP2.yaml => Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml} (65%) diff --git a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml new file mode 100644 index 00000000000..657251a6603 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml @@ -0,0 +1,12 @@ +model_name: "amd/Qwen3.5-35B-A3B-MXFP4" +accuracy_threshold: 0.89 +tolerance: 0.03 +num_questions: 1319 +num_fewshot: 5 +server_args: >- + --max-model-len 4096 + --tensor-parallel-size 2 + --gpu-memory-utilization 0.35 + --moe-backend aiter +env: + VLLM_ROCM_USE_AITER: "1" diff --git a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-TP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml similarity index 65% rename from tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-TP2.yaml rename to tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml index 67eda114155..ad5ca701258 100644 --- a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-TP2.yaml +++ b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml @@ -1,8 +1,10 @@ model_name: "amd/Qwen3.5-35B-A3B-MXFP4" -accuracy_threshold: 0.82 +accuracy_threshold: 0.89 tolerance: 0.03 num_questions: 1319 num_fewshot: 5 server_args: >- --max-model-len 4096 --tensor-parallel-size 2 + --moe-backend emulation + --gpu-memory-utilization 0.35 diff --git a/tests/evals/gsm8k/configs/models-mi3xx.txt b/tests/evals/gsm8k/configs/models-mi3xx.txt index dfa4bc8eb53..e8759d7d02b 100644 --- a/tests/evals/gsm8k/configs/models-mi3xx.txt +++ b/tests/evals/gsm8k/configs/models-mi3xx.txt @@ -3,4 +3,5 @@ DeepSeek-R1-DP_MI325.yaml DeepSeek-V3.2-TP_MI325.yaml DeepSeek-V3.2-DP_MI325.yaml Qwen3-30B-A3B-NVFP4.yaml -Qwen3.5-35B-A3B-MXFP4-TP2.yaml \ No newline at end of file +Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml +Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/models-qwen35-mi355.txt b/tests/evals/gsm8k/configs/models-qwen35-mi355.txt index db8e88e2735..49925c827e3 100644 --- a/tests/evals/gsm8k/configs/models-qwen35-mi355.txt +++ b/tests/evals/gsm8k/configs/models-qwen35-mi355.txt @@ -1,2 +1,3 @@ Qwen3.5-35B-A3B-DEP2.yaml -Qwen3.5-35B-A3B-MXFP4-TP2.yaml +Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml +Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml diff --git a/tests/quantization/test_gfx950_moe.py b/tests/quantization/test_gfx950_moe.py index 9cb94086f73..4b65961d8db 100644 --- a/tests/quantization/test_gfx950_moe.py +++ b/tests/quantization/test_gfx950_moe.py @@ -1,6 +1,90 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for MXFP4 MoE oracle backend selection on mi355x (GFX950). + +These tests run on real hardware — no mocks. Skipped on non-GFX950 platforms. +""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + select_mxfp4_moe_backend, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kMxfp4Dynamic, +) +from vllm.platforms import current_platform + +ROCM_AVAILABLE = current_platform.is_rocm() +ROCM_GFX950 = False +ROCM_AITER_AVAILABLE = False + +if ROCM_AVAILABLE: + from vllm._aiter_ops import rocm_aiter_ops + from vllm.platforms.rocm import on_gfx950 + + ROCM_GFX950 = on_gfx950() + ROCM_AITER_AVAILABLE = rocm_aiter_ops.is_fused_moe_enabled() -def test_mi355_moe(): - print("TODO: add tests for Mi355 MoE quantization") +def _make_w4a4_moe_config(moe_backend: str = "auto") -> FusedMoEConfig: + from vllm.model_executor.layers.fused_moe.activation import MoEActivation + + return FusedMoEConfig( + num_experts=8, + experts_per_token=2, + hidden_dim=256, + intermediate_size_per_partition=256, + num_local_experts=8, + num_logical_experts=8, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + activation=MoEActivation.SILU, + in_dtype=torch.bfloat16, + device="cuda", + routing_method=RoutingMethodType.Renormalize, + moe_backend=moe_backend, + ) + + +@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)") +@pytest.mark.skipif(not ROCM_AITER_AVAILABLE, reason="Requires AITER enabled") +def test_w4a4_dispatches_to_aiter(): + """With AITER enabled + GFX950, W4A4 selects AITER_MXFP4_MXFP4.""" + config = _make_w4a4_moe_config() + backend, experts_cls = select_mxfp4_moe_backend( + config, activation_key=kMxfp4Dynamic + ) + assert backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4 + assert experts_cls is not None + + +@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)") +@pytest.mark.skipif( + ROCM_AITER_AVAILABLE, + reason="Test requires AITER disabled (unset VLLM_ROCM_USE_AITER)", +) +def test_w4a4_raises_without_aiter_and_no_moe_backend(): + """Without AITER and no --moe-backend, raises NotImplementedError + with hint to use --moe-backend emulation.""" + config = _make_w4a4_moe_config() + with pytest.raises(NotImplementedError, match="--moe-backend emulation"): + select_mxfp4_moe_backend(config, activation_key=kMxfp4Dynamic) + + +@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)") +def test_w4a4_dispatches_to_emulation_with_moe_backend(): + """With --moe-backend emulation, W4A4 selects EMULATION.""" + config = _make_w4a4_moe_config(moe_backend="emulation") + backend, experts_cls = select_mxfp4_moe_backend( + config, activation_key=kMxfp4Dynamic + ) + assert backend == Mxfp4MoeBackend.EMULATION + assert experts_cls is not None diff --git a/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py b/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py index be7d1f99215..2b8abbfc50d 100644 --- a/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py @@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8Static128BlockSym, kFp8StaticChannelSym, kFp8StaticTensorSym, + kMxfp4Dynamic, kMxfp4Static, ) @@ -377,6 +378,21 @@ class AiterExperts(mk.FusedMoEExpertsModular): def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard + @staticmethod + def is_supported_config( + cls, moe_config, weight_key, activation_key, activation_format + ): + is_supported, reason = super().is_supported_config( + cls, moe_config, weight_key, activation_key, activation_format + ) + if not is_supported and not rocm_aiter_ops.is_fused_moe_enabled(): + reason = ( + f"{reason}. AITER MoE is not enabled — " + "set VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE=1 " + "to enable it" + ) + return is_supported, reason + @staticmethod def _supports_current_device() -> bool: return rocm_aiter_ops.is_fused_moe_enabled() @@ -397,6 +413,7 @@ class AiterExperts(mk.FusedMoEExpertsModular): (kFp8StaticTensorSym, kFp8DynamicTensorSym), (kFp8StaticChannelSym, kFp8DynamicTokenSym), (kMxfp4Static, None), + (kMxfp4Static, kMxfp4Dynamic), ] if (weight_key, activation_key) not in SUPPORTED_W_A: return False diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 74210b69ed5..a69bae1b051 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -31,6 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8Dynamic128Sym, kFp8StaticTensorSym, + kMxfp4Dynamic, kMxfp4Static, kMxfp8Dynamic, ) @@ -74,6 +75,7 @@ class Mxfp4MoeBackend(Enum): # Keep the legacy name as an alias while the ROCm split backend rename settles. AITER = "AITER_MXFP4_BF16" AITER_MXFP4_FP8 = "AITER_MXFP4_FP8" # W4A8: triton kernel + AITER_MXFP4_MXFP4 = "AITER_MXFP4_MXFP4" # W4A4: CK kernel # Triton TRITON = "TRITON" TRITON_UNFUSED = "TRITON_UNFUSED" @@ -91,6 +93,7 @@ class Mxfp4MoeBackend(Enum): AITER_BACKENDS = ( Mxfp4MoeBackend.AITER_MXFP4_BF16, Mxfp4MoeBackend.AITER_MXFP4_FP8, + Mxfp4MoeBackend.AITER_MXFP4_MXFP4, ) @@ -195,6 +198,13 @@ def backend_to_kernel_cls( return [AiterW4A8ExpertsMonolithic] + elif backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import ( + AiterExperts, + ) + + return [AiterExperts] + elif backend == Mxfp4MoeBackend.XPU: from vllm.model_executor.layers.fused_moe.experts.xpu_moe import XPUExpertsMXFp4 @@ -241,8 +251,10 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> list[Mxfp4MoeBackend]: "aiter": [ Mxfp4MoeBackend.AITER_MXFP4_BF16, Mxfp4MoeBackend.AITER_MXFP4_FP8, + Mxfp4MoeBackend.AITER_MXFP4_MXFP4, ], "aiter_mxfp4_fp8": [Mxfp4MoeBackend.AITER_MXFP4_FP8], + "aiter_mxfp4_mxfp4": [Mxfp4MoeBackend.AITER_MXFP4_MXFP4], "xpu": [Mxfp4MoeBackend.XPU], "cpu": [Mxfp4MoeBackend.CPU], "emulation": [Mxfp4MoeBackend.EMULATION], @@ -263,6 +275,7 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]: Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.AITER_MXFP4_BF16, Mxfp4MoeBackend.AITER_MXFP4_FP8, + Mxfp4MoeBackend.AITER_MXFP4_MXFP4, Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, @@ -272,7 +285,6 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]: Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, Mxfp4MoeBackend.XPU, - Mxfp4MoeBackend.EMULATION, ] return _AVAILABLE_BACKENDS @@ -308,6 +320,8 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: return kMxfp8Dynamic if backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: return kFp8StaticTensorSym + if backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + return kMxfp4Dynamic return None # BF16 activation @@ -379,6 +393,7 @@ def _filter_by_activation( b for b in backends if _backend_activation_key(b) == requested_activation_key + or b == Mxfp4MoeBackend.EMULATION ] bf16 = [b for b in backends if _backend_activation_key(b) is None] return bf16 if bf16 else backends @@ -556,7 +571,12 @@ def select_mxfp4_moe_backend( if current_platform.is_cuda() or current_platform.is_rocm(): raise NotImplementedError( - "No MXFP4 MoE backend supports the deployment configuration." + "No MXFP4 MoE backend supports the deployment configuration. " + f"weight_key=kMxfp4Static, activation_key={activation_key}. " + "Native backends require specific hardware. " + "Set `VLLM_LOGGING_LEVEL=DEBUG` to see detailed unsupported reasons. " + "To use the emulation backend for research/debugging, pass " + "--moe-backend emulation." ) return Mxfp4MoeBackend.NONE, None @@ -957,6 +977,49 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + from vllm._aiter_ops import rocm_aiter_ops + + if w13_bias is not None: + w13_bias = w13_bias.data.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.data.to(torch.float32) + + # e8m0_shuffle on weight scales (GFX950 swizzle layout) + from aiter.utility.fp4_utils import e8m0_shuffle + + s0, s1, _ = w13_weight_scale.shape + w13_weight_scale.data = e8m0_shuffle(w13_weight_scale.view(s0 * s1, -1)).view( + s0, s1, -1 + ) + + s0, s1, _ = w2_weight_scale.shape + w2_weight_scale.data = e8m0_shuffle(w2_weight_scale.view(s0 * s1, -1)).view( + s0, s1, -1 + ) + + # View as native FP4 dtype + fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None) + if fp4_dtype is not None: + w13_weight.data = w13_weight.data.view(fp4_dtype) + w2_weight.data = w2_weight.data.view(fp4_dtype) + + # Shuffle weights for AITER CK kernel + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( + w13_weight, w2_weight + ) + shuffled_w13.is_shuffled = True + shuffled_w2.is_shuffled = True + + return ( + shuffled_w13, + shuffled_w2, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_BF16: from vllm._aiter_ops import rocm_aiter_ops @@ -1541,6 +1604,17 @@ def make_mxfp4_moe_quant_config( block_shape=None, gemm1_clamp_limit=swiglu_limit, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + return ocp_mx_moe_quant_config( + quant_dtype="mxfp4", + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, + ) elif mxfp4_backend in ( Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index ad2b842e782..627a44a77b5 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -55,6 +55,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, + kMxfp4Dynamic, kNvfp4Dynamic, kNvfp4Static, ) @@ -1042,6 +1043,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend( moe, activation_key=kFp8StaticTensorSym ) + elif self.ocp_mx_scheme == "w_mxfp4_a_mxfp4": + # W4A4: MXFP4 weights + MXFP4 activations + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend( + moe, activation_key=kMxfp4Dynamic + ) # Validation for unsupported schemes if any( @@ -1061,45 +1067,19 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): "Please open an issue." ) - self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() - self.model_type = getattr( get_current_vllm_config().model_config.hf_config, "model_type", None ) - # TODO: Remove once all OCP MX schemes use the kernel abstraction - _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8") - self.emulate = ( - not current_platform.supports_mx() - or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES - ) and ( - self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe - ) - - if self.emulate: - # We use the same code path between MXFP4/MXFP6 emulation. + # If no native backend available, use emulation. + if self.mxfp4_backend is Mxfp4MoeBackend.NONE: self.mxfp4_backend = Mxfp4MoeBackend.EMULATION - # TODO: Remove `self.mxfp4_backend != Mxfp4MoeBackend.NONE` and make it so that - # all MXFP4 backends use the kernel abstraction. - if self.mxfp4_backend != Mxfp4MoeBackend.NONE: - self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0] + self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0] - # Log backend selection - if self.mxfp4_backend != Mxfp4MoeBackend.NONE: - logger.info_once( - f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}" - ) - elif self.emulate: - logger.warning_once( - f"The current mode (supports_mx={current_platform.supports_mx()}, " - f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, " - f"ocp_mx_scheme={self.ocp_mx_scheme}) " - "does not support native MXFP4/MXFP6 " - "computation. Simulated weight dequantization and activation " - "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision." - ) + logger.info_once( + f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}" + ) def maybe_roundup_sizes( self, @@ -1243,97 +1223,8 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): layer.w13_input_scale = None layer.w2_input_scale = None - def process_weights_after_loading(self, layer: RoutedExperts): - # For MXFP4 schemes with native backend, use oracle - if self.mxfp4_backend != Mxfp4MoeBackend.NONE: - self._setup_kernel(layer) - return - - if self.static_input_scales and self.input_dtype == "fp8": - # firstly, process activations if fp8 static input - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - if not all_close_1d(layer.w13_input_scale) or not all_close_1d( - layer.w2_input_scale - ): - logger.warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. " - ) - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False - ) - - if current_platform.is_fp8_fnuz(): - # Normalize the weights and scales - _, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( - torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fn), - torch.empty_like( - layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype - ), - layer.w13_input_scale, - ) - _, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( - torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fn), - torch.empty_like( - layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype - ), - layer.w2_input_scale, - ) - # Reset the parameter - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) - - # TODO(bowenbao): gradually migrate to oracles. - # Existing AITER path for w_mxfp4_a_mxfp4 and other schemes - from aiter.utility.fp4_utils import e8m0_shuffle - - # Pre-shuffle weight scales - s0, s1, _ = layer.w13_weight_scale.shape - w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1) - w13_weight_scale = e8m0_shuffle(w13_weight_scale) - layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1) - - s0, s1, _ = layer.w2_weight_scale.shape - w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1) - w2_weight_scale = e8m0_shuffle(w2_weight_scale) - layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1) - - if self.fp4_dtype is not None: - layer.w13_weight = torch.nn.Parameter( - layer.w13_weight.view(self.fp4_dtype), - requires_grad=layer.w13_weight.requires_grad, - ) - layer.w2_weight = torch.nn.Parameter( - layer.w2_weight.view(self.fp4_dtype), - requires_grad=layer.w2_weight.requires_grad, - ) - # Pre-shuffle weight - shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data - ) - - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - layer.w13_weight.is_shuffled = True - layer.w2_weight.is_shuffled = True - - # Build quant config for AITER path - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - torch.accelerator.empty_cache() + def process_weights_after_loading(self, layer): + self._setup_kernel(layer) def _setup_kernel(self, layer: RoutedExperts): """Setup kernel using oracle functions for MXFP4 schemes (W4A16, W4A8).""" @@ -1375,6 +1266,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): replace_parameter(layer, "w13_bias", w13_bias) replace_parameter(layer, "w2_bias", w2_bias) + if self.mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + layer.w13_weight.is_shuffled = True + layer.w2_weight.is_shuffled = True + torch.accelerator.empty_cache() # Build quant config and kernel @@ -1464,38 +1359,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): shared_experts: SharedExperts | None, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: - # For oracle-based kernels (W4A16, W4A8) or emulation kernel - if self.moe_kernel is not None: - return self.moe_kernel.apply( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - expert_map=layer.expert_map, - shared_experts=shared_experts, - shared_experts_input=shared_experts_input, - ) - - # AITER path - # TODO: Refactor this to use modular MOE kernel as well. - from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import ( - rocm_aiter_fused_experts, - ) - - return rocm_aiter_fused_experts( - x, - layer.w13_weight, - layer.w2_weight, + assert self.moe_kernel is not None + return self.moe_kernel.apply( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=layer.activation, - quant_config=self.moe_quant_config, - moe_config=layer.moe_config, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, + shared_experts_input=shared_experts_input, ) def apply_monolithic(