mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][Quantization][3/N] Refactor quark_moe w4a4 w/ oracle (#41436)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
|
||||
accuracy_threshold: 0.89
|
||||
tolerance: 0.03
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: >-
|
||||
--max-model-len 4096
|
||||
--tensor-parallel-size 2
|
||||
--gpu-memory-utilization 0.35
|
||||
--moe-backend aiter
|
||||
env:
|
||||
VLLM_ROCM_USE_AITER: "1"
|
||||
+3
-1
@@ -1,8 +1,10 @@
|
||||
model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
|
||||
accuracy_threshold: 0.82
|
||||
accuracy_threshold: 0.89
|
||||
tolerance: 0.03
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: >-
|
||||
--max-model-len 4096
|
||||
--tensor-parallel-size 2
|
||||
--moe-backend emulation
|
||||
--gpu-memory-utilization 0.35
|
||||
@@ -3,4 +3,5 @@ DeepSeek-R1-DP_MI325.yaml
|
||||
DeepSeek-V3.2-TP_MI325.yaml
|
||||
DeepSeek-V3.2-DP_MI325.yaml
|
||||
Qwen3-30B-A3B-NVFP4.yaml
|
||||
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
|
||||
Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
|
||||
Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
|
||||
@@ -1,2 +1,3 @@
|
||||
Qwen3.5-35B-A3B-DEP2.yaml
|
||||
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
|
||||
Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
|
||||
Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
|
||||
|
||||
@@ -1,6 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for MXFP4 MoE oracle backend selection on mi355x (GFX950).
|
||||
|
||||
These tests run on real hardware — no mocks. Skipped on non-GFX950 platforms.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
Mxfp4MoeBackend,
|
||||
select_mxfp4_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kMxfp4Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Platform capability flags, computed once at import time; the skipif
# decorators on the tests in this module read them.
ROCM_AVAILABLE = current_platform.is_rocm()
ROCM_GFX950 = ROCM_AITER_AVAILABLE = False

if ROCM_AVAILABLE:
    # These modules are imported only when the current platform is ROCm.
    from vllm._aiter_ops import rocm_aiter_ops
    from vllm.platforms.rocm import on_gfx950

    ROCM_GFX950, ROCM_AITER_AVAILABLE = (
        on_gfx950(),
        rocm_aiter_ops.is_fused_moe_enabled(),
    )
|
||||
|
||||
|
||||
def test_mi355_moe():
|
||||
print("TODO: add tests for Mi355 MoE quantization")
|
||||
def _make_w4a4_moe_config(moe_backend: str = "auto") -> FusedMoEConfig:
    """Build the small ``FusedMoEConfig`` shared by the backend-selection tests.

    Args:
        moe_backend: Value forwarded as ``FusedMoEConfig.moe_backend``;
            defaults to ``"auto"``.

    Returns:
        A config with 8 experts (all local), 2 experts per token, hidden and
        intermediate sizes of 256, SiLU activation, bfloat16 inputs, and the
        parallel config from ``FusedMoEParallelConfig.make_no_parallel()``.
    """
    from vllm.model_executor.layers.fused_moe.activation import MoEActivation

    settings = dict(
        num_experts=8,
        experts_per_token=2,
        hidden_dim=256,
        intermediate_size_per_partition=256,
        num_local_experts=8,
        num_logical_experts=8,
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        activation=MoEActivation.SILU,
        in_dtype=torch.bfloat16,
        device="cuda",
        routing_method=RoutingMethodType.Renormalize,
        moe_backend=moe_backend,
    )
    return FusedMoEConfig(**settings)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
@pytest.mark.skipif(not ROCM_AITER_AVAILABLE, reason="Requires AITER enabled")
def test_w4a4_dispatches_to_aiter():
    """W4A4 with the default backend setting must select AITER_MXFP4_MXFP4.

    Runs only when both the GFX950 and the AITER fused-MoE flags are set.
    """
    selection = select_mxfp4_moe_backend(
        _make_w4a4_moe_config(), activation_key=kMxfp4Dynamic
    )
    chosen_backend, kernel_cls = selection

    assert chosen_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4
    assert kernel_cls is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
@pytest.mark.skipif(
    ROCM_AITER_AVAILABLE,
    reason="Test requires AITER disabled (unset VLLM_ROCM_USE_AITER)",
)
def test_w4a4_raises_without_aiter_and_no_moe_backend():
    """W4A4 selection must fail with an emulation hint when AITER is off.

    With the default backend setting and AITER disabled, the oracle is
    expected to raise ``NotImplementedError`` whose message points the user
    at ``--moe-backend emulation``.
    """
    default_config = _make_w4a4_moe_config()
    expected_hint = "--moe-backend emulation"

    with pytest.raises(NotImplementedError, match=expected_hint):
        select_mxfp4_moe_backend(default_config, activation_key=kMxfp4Dynamic)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
def test_w4a4_dispatches_to_emulation_with_moe_backend():
    """Requesting the "emulation" backend must select EMULATION for W4A4."""
    emulation_config = _make_w4a4_moe_config(moe_backend="emulation")
    selection = select_mxfp4_moe_backend(
        emulation_config, activation_key=kMxfp4Dynamic
    )
    chosen_backend, kernel_cls = selection

    assert chosen_backend == Mxfp4MoeBackend.EMULATION
    assert kernel_cls is not None
|
||||
|
||||
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Dynamic,
|
||||
kMxfp4Static,
|
||||
)
|
||||
|
||||
@@ -377,6 +378,21 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
def is_supported_config(
    cls, moe_config, weight_key, activation_key, activation_format
):
    """Check config support, adding an AITER hint to the rejection reason.

    Delegates to the parent implementation. When the parent reports the
    config as unsupported and ``rocm_aiter_ops.is_fused_moe_enabled()`` is
    false, the reason is extended with the environment variables that
    enable AITER MoE.

    Note: this is a staticmethod that takes ``cls`` explicitly; the
    zero-argument ``super()`` relies on ``cls`` being the first argument.

    Returns:
        The ``(is_supported, reason)`` pair from the parent, with ``reason``
        extended as described above when applicable.
    """
    supported, why = super().is_supported_config(
        cls, moe_config, weight_key, activation_key, activation_format
    )

    # Nothing to add when the config is supported or AITER MoE is already on.
    if supported or rocm_aiter_ops.is_fused_moe_enabled():
        return supported, why

    hinted_reason = (
        f"{why}. AITER MoE is not enabled — "
        "set VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE=1 "
        "to enable it"
    )
    return supported, hinted_reason
|
||||
|
||||
@staticmethod
def _supports_current_device() -> bool:
    """Return whether ``rocm_aiter_ops`` reports AITER fused MoE as enabled."""
    aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
    return aiter_moe_enabled
|
||||
@@ -397,6 +413,7 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
|
||||
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
|
||||
(kMxfp4Static, None),
|
||||
(kMxfp4Static, kMxfp4Dynamic),
|
||||
]
|
||||
if (weight_key, activation_key) not in SUPPORTED_W_A:
|
||||
return False
|
||||
|
||||
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Dynamic,
|
||||
kMxfp4Static,
|
||||
kMxfp8Dynamic,
|
||||
)
|
||||
@@ -74,6 +75,7 @@ class Mxfp4MoeBackend(Enum):
|
||||
# Keep the legacy name as an alias while the ROCm split backend rename settles.
|
||||
AITER = "AITER_MXFP4_BF16"
|
||||
AITER_MXFP4_FP8 = "AITER_MXFP4_FP8" # W4A8: triton kernel
|
||||
AITER_MXFP4_MXFP4 = "AITER_MXFP4_MXFP4" # W4A4: CK kernel
|
||||
# Triton
|
||||
TRITON = "TRITON"
|
||||
TRITON_UNFUSED = "TRITON_UNFUSED"
|
||||
@@ -91,6 +93,7 @@ class Mxfp4MoeBackend(Enum):
|
||||
AITER_BACKENDS = (
|
||||
Mxfp4MoeBackend.AITER_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_FP8,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
|
||||
)
|
||||
|
||||
|
||||
@@ -195,6 +198,13 @@ def backend_to_kernel_cls(
|
||||
|
||||
return [AiterW4A8ExpertsMonolithic]
|
||||
|
||||
elif backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
|
||||
from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
return [AiterExperts]
|
||||
|
||||
elif backend == Mxfp4MoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe.experts.xpu_moe import XPUExpertsMXFp4
|
||||
|
||||
@@ -241,8 +251,10 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> list[Mxfp4MoeBackend]:
|
||||
"aiter": [
|
||||
Mxfp4MoeBackend.AITER_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_FP8,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
|
||||
],
|
||||
"aiter_mxfp4_fp8": [Mxfp4MoeBackend.AITER_MXFP4_FP8],
|
||||
"aiter_mxfp4_mxfp4": [Mxfp4MoeBackend.AITER_MXFP4_MXFP4],
|
||||
"xpu": [Mxfp4MoeBackend.XPU],
|
||||
"cpu": [Mxfp4MoeBackend.CPU],
|
||||
"emulation": [Mxfp4MoeBackend.EMULATION],
|
||||
@@ -263,6 +275,7 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_FP8,
|
||||
Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
|
||||
Mxfp4MoeBackend.TRITON,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
|
||||
@@ -272,7 +285,6 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
|
||||
Mxfp4MoeBackend.MARLIN,
|
||||
Mxfp4MoeBackend.BATCHED_MARLIN,
|
||||
Mxfp4MoeBackend.XPU,
|
||||
Mxfp4MoeBackend.EMULATION,
|
||||
]
|
||||
return _AVAILABLE_BACKENDS
|
||||
|
||||
@@ -308,6 +320,8 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
|
||||
return kMxfp8Dynamic
|
||||
if backend == Mxfp4MoeBackend.AITER_MXFP4_FP8:
|
||||
return kFp8StaticTensorSym
|
||||
if backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
|
||||
return kMxfp4Dynamic
|
||||
return None # BF16 activation
|
||||
|
||||
|
||||
@@ -379,6 +393,7 @@ def _filter_by_activation(
|
||||
b
|
||||
for b in backends
|
||||
if _backend_activation_key(b) == requested_activation_key
|
||||
or b == Mxfp4MoeBackend.EMULATION
|
||||
]
|
||||
bf16 = [b for b in backends if _backend_activation_key(b) is None]
|
||||
return bf16 if bf16 else backends
|
||||
@@ -556,7 +571,12 @@ def select_mxfp4_moe_backend(
|
||||
|
||||
if current_platform.is_cuda() or current_platform.is_rocm():
|
||||
raise NotImplementedError(
|
||||
"No MXFP4 MoE backend supports the deployment configuration."
|
||||
"No MXFP4 MoE backend supports the deployment configuration. "
|
||||
f"weight_key=kMxfp4Static, activation_key={activation_key}. "
|
||||
"Native backends require specific hardware. "
|
||||
"Set `VLLM_LOGGING_LEVEL=DEBUG` to see detailed unsupported reasons. "
|
||||
"To use the emulation backend for research/debugging, pass "
|
||||
"--moe-backend emulation."
|
||||
)
|
||||
|
||||
return Mxfp4MoeBackend.NONE, None
|
||||
@@ -957,6 +977,49 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if w13_bias is not None:
|
||||
w13_bias = w13_bias.data.to(torch.float32)
|
||||
if w2_bias is not None:
|
||||
w2_bias = w2_bias.data.to(torch.float32)
|
||||
|
||||
# e8m0_shuffle on weight scales (GFX950 swizzle layout)
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
s0, s1, _ = w13_weight_scale.shape
|
||||
w13_weight_scale.data = e8m0_shuffle(w13_weight_scale.view(s0 * s1, -1)).view(
|
||||
s0, s1, -1
|
||||
)
|
||||
|
||||
s0, s1, _ = w2_weight_scale.shape
|
||||
w2_weight_scale.data = e8m0_shuffle(w2_weight_scale.view(s0 * s1, -1)).view(
|
||||
s0, s1, -1
|
||||
)
|
||||
|
||||
# View as native FP4 dtype
|
||||
fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
|
||||
if fp4_dtype is not None:
|
||||
w13_weight.data = w13_weight.data.view(fp4_dtype)
|
||||
w2_weight.data = w2_weight.data.view(fp4_dtype)
|
||||
|
||||
# Shuffle weights for AITER CK kernel
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
w13_weight, w2_weight
|
||||
)
|
||||
shuffled_w13.is_shuffled = True
|
||||
shuffled_w2.is_shuffled = True
|
||||
|
||||
return (
|
||||
shuffled_w13,
|
||||
shuffled_w2,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_BF16:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
@@ -1541,6 +1604,17 @@ def make_mxfp4_moe_quant_config(
|
||||
block_shape=None,
|
||||
gemm1_clamp_limit=swiglu_limit,
|
||||
)
|
||||
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
|
||||
return ocp_mx_moe_quant_config(
|
||||
quant_dtype="mxfp4",
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=swiglu_limit,
|
||||
)
|
||||
elif mxfp4_backend in (
|
||||
Mxfp4MoeBackend.MARLIN,
|
||||
Mxfp4MoeBackend.BATCHED_MARLIN,
|
||||
|
||||
@@ -55,6 +55,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Dynamic,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
@@ -1042,6 +1043,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(
|
||||
moe, activation_key=kFp8StaticTensorSym
|
||||
)
|
||||
elif self.ocp_mx_scheme == "w_mxfp4_a_mxfp4":
|
||||
# W4A4: MXFP4 weights + MXFP4 activations
|
||||
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(
|
||||
moe, activation_key=kMxfp4Dynamic
|
||||
)
|
||||
|
||||
# Validation for unsupported schemes
|
||||
if any(
|
||||
@@ -1061,45 +1067,19 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"Please open an issue."
|
||||
)
|
||||
|
||||
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
self.model_type = getattr(
|
||||
get_current_vllm_config().model_config.hf_config, "model_type", None
|
||||
)
|
||||
|
||||
# TODO: Remove once all OCP MX schemes use the kernel abstraction
|
||||
_AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8")
|
||||
self.emulate = (
|
||||
not current_platform.supports_mx()
|
||||
or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES
|
||||
) and (
|
||||
self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
|
||||
)
|
||||
|
||||
if self.emulate:
|
||||
# We use the same code path between MXFP4/MXFP6 emulation.
|
||||
# If no native backend available, use emulation.
|
||||
if self.mxfp4_backend is Mxfp4MoeBackend.NONE:
|
||||
self.mxfp4_backend = Mxfp4MoeBackend.EMULATION
|
||||
|
||||
# TODO: Remove `self.mxfp4_backend != Mxfp4MoeBackend.NONE` and make it so that
|
||||
# all MXFP4 backends use the kernel abstraction.
|
||||
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
|
||||
self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
|
||||
self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
|
||||
|
||||
# Log backend selection
|
||||
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
|
||||
logger.info_once(
|
||||
f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}"
|
||||
)
|
||||
elif self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
|
||||
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
|
||||
"does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
logger.info_once(
|
||||
f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}"
|
||||
)
|
||||
|
||||
def maybe_roundup_sizes(
|
||||
self,
|
||||
@@ -1243,97 +1223,8 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: RoutedExperts):
|
||||
# For MXFP4 schemes with native backend, use oracle
|
||||
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
|
||||
self._setup_kernel(layer)
|
||||
return
|
||||
|
||||
if self.static_input_scales and self.input_dtype == "fp8":
|
||||
# firstly, process activations if fp8 static input
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. "
|
||||
)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
_, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fn),
|
||||
torch.empty_like(
|
||||
layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
|
||||
),
|
||||
layer.w13_input_scale,
|
||||
)
|
||||
_, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fn),
|
||||
torch.empty_like(
|
||||
layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
|
||||
),
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
# Reset the parameter
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False
|
||||
)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# TODO(bowenbao): gradually migrate to oracles.
|
||||
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
# Pre-shuffle weight scales
|
||||
s0, s1, _ = layer.w13_weight_scale.shape
|
||||
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
|
||||
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
|
||||
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
|
||||
|
||||
s0, s1, _ = layer.w2_weight_scale.shape
|
||||
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
|
||||
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
|
||||
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||
|
||||
if self.fp4_dtype is not None:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
layer.w13_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w13_weight.requires_grad,
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
layer.w2_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w2_weight.requires_grad,
|
||||
)
|
||||
# Pre-shuffle weight
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
layer.w13_weight.is_shuffled = True
|
||||
layer.w2_weight.is_shuffled = True
|
||||
|
||||
# Build quant config for AITER path
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
torch.accelerator.empty_cache()
|
||||
def process_weights_after_loading(self, layer: RoutedExperts) -> None:
    """Finish weight loading by handing the layer to ``_setup_kernel``.

    Every backend goes through ``_setup_kernel``; this method does no
    weight processing of its own.

    Args:
        layer: The layer passed through unchanged to ``_setup_kernel``.
    """
    self._setup_kernel(layer)
|
||||
|
||||
def _setup_kernel(self, layer: RoutedExperts):
|
||||
"""Set up the oracle-selected kernel for MXFP4 schemes (W4A16, W4A8, W4A4)."""
|
||||
@@ -1375,6 +1266,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
replace_parameter(layer, "w13_bias", w13_bias)
|
||||
replace_parameter(layer, "w2_bias", w2_bias)
|
||||
|
||||
if self.mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
|
||||
layer.w13_weight.is_shuffled = True
|
||||
layer.w2_weight.is_shuffled = True
|
||||
|
||||
torch.accelerator.empty_cache()
|
||||
|
||||
# Build quant config and kernel
|
||||
@@ -1464,38 +1359,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
shared_experts: SharedExperts | None,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# For oracle-based kernels (W4A16, W4A8) or emulation kernel
|
||||
if self.moe_kernel is not None:
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts=shared_experts,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
# AITER path
|
||||
# TODO: Refactor this to use modular MOE kernel as well.
|
||||
from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=layer.moe_config,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
|
||||
Reference in New Issue
Block a user