[XPU][MoE] Add WNA16 oracle backend for GPTQ sym-int4 (xpu_fused_moe) (#41426)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Jason Elie Bou Kheir
2026-05-28 09:30:48 -07:00
committed by GitHub
parent a9ec46d4b7
commit 3207e7680e
2 changed files with 179 additions and 7 deletions
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
kInt4Static,
kMxfp4Static,
kMxfp8Dynamic,
kMxfp8Static,
@@ -48,6 +49,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
num_dispatchers,
)
self.is_fp8 = False
self.is_int4 = False
self.is_mxfp4 = False
self.is_mxfp8 = False
self.fused_moe_impl: XpuFusedMoe | None = None
@@ -133,6 +135,15 @@ class XPUExperts(mk.FusedMoEExpertsModular):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
# The kernel takes is_fp8/is_int4/is_mxfp4 as independent booleans.
# In this hierarchy each subclass flips exactly one to True; assert
# the invariant so a future subclass that sets two doesn't silently
# miscompute (kernel-side priority is undocumented).
assert sum([self.is_fp8, self.is_int4, self.is_mxfp4]) <= 1, (
"XPUExperts: at most one of is_fp8, is_int4, is_mxfp4 may be True; "
f"got is_fp8={self.is_fp8}, is_int4={self.is_int4}, "
f"is_mxfp4={self.is_mxfp4}."
)
if self.fused_moe_impl is None:
topk = topk_ids.size(-1)
self.fused_moe_impl = XpuFusedMoe(
@@ -148,6 +159,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
ep_rank=self.moe_config.ep_rank,
ep_size=self.moe_config.ep_size,
is_fp8=self.is_fp8,
is_int4=self.is_int4,
is_mxfp4=self.is_mxfp4,
is_mxfp8=self.is_mxfp8,
)
@@ -217,6 +229,42 @@ class XPUExpertsMxfp8(XPUExpertsFp8):
return (weight_key, activation_key) in SUPPORTED_W_A
class XPUExpertsWNA16(XPUExperts):
"""W4A16 INT4-symmetric MoE backed by `xpu_fused_moe(is_int4=True)`.
Weight layout when `is_int4=True` (per `xpu_fused_moe` docstring):
w13: [num_experts, 2*inter_size, hidden_size] contiguous int4-packed
w13_scales: [num_experts, 2*inter_size, hidden_size // group_size]
w2: [num_experts, hidden_size, inter_size] contiguous int4-packed
w2_scales: [num_experts, hidden_size, inter_size // group_size]
Pairs with `INCXPULinearMethod` for the linear layers; together they
cover full-attn + MoE on Intel XPU end-to-end without IPEX.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
super().__init__(
moe_config,
quant_config,
max_num_tokens,
num_dispatchers,
)
self.is_int4 = True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return (weight_key, activation_key) == (kInt4Static, None)
class XPUExpertsMXFp4(XPUExperts):
def __init__(
self,
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from enum import Enum
from typing import TYPE_CHECKING
@@ -26,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.auto_gptq import AutoGPTQConfig
@@ -37,6 +39,7 @@ logger = init_logger(__name__)
class WNA16MoEBackend(Enum):
MARLIN = "MARLIN"
BATCHED_MARLIN = "BATCHED_MARLIN"
XPU = "XPU"
def backend_to_kernel_cls(
@@ -57,6 +60,13 @@ def backend_to_kernel_cls(
return [BatchedMarlinExperts]
elif backend == WNA16MoEBackend.XPU:
from vllm.model_executor.layers.fused_moe.experts.xpu_moe import (
XPUExpertsWNA16,
)
return [XPUExpertsWNA16]
else:
raise ValueError(f"Unknown WNA16 MoE backend: {backend.value}")
@@ -65,11 +75,12 @@ def _get_priority_backends() -> list[WNA16MoEBackend]:
"""
Get available backends in priority order based on platform and config.
"""
_AVAILABLE_BACKENDS = [
if current_platform.is_xpu():
return [WNA16MoEBackend.XPU]
return [
WNA16MoEBackend.MARLIN,
WNA16MoEBackend.BATCHED_MARLIN,
]
return _AVAILABLE_BACKENDS
def select_wna16_moe_backend(
@@ -156,12 +167,14 @@ def make_wna16_moe_kernel(
w2_g_idx_sort_indices: torch.Tensor | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEKernel:
# Currently, we only support MarlinExperts and BatchedMarlinExperts
assert experts_cls in (MarlinExperts, BatchedMarlinExperts)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.experts.xpu_moe import (
XPUExpertsWNA16,
)
assert experts_cls in (MarlinExperts, BatchedMarlinExperts, XPUExpertsWNA16)
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
@@ -172,11 +185,24 @@ def make_wna16_moe_kernel(
assert prepare_finalize is not None
assert isinstance(prepare_finalize, mk.FusedMoEPrepareAndFinalizeModular)
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
if experts_cls is XPUExpertsWNA16:
assert (
prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard
), (
"XPUExpertsWNA16 only supports the Standard activation format; "
"xpu_fused_moe(is_int4=True) does not implement BatchedExperts."
)
experts: mk.FusedMoEExperts = XPUExpertsWNA16(
moe_config=moe_config,
quant_config=moe_quant_config,
)
elif (
prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts
):
assert experts_cls == BatchedMarlinExperts
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts: mk.FusedMoEExperts = BatchedMarlinExperts(
experts = BatchedMarlinExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=moe_config,
@@ -524,6 +550,69 @@ def _process_awq_weights_marlin(
)
def _process_weights_xpu(
layer: torch.nn.Module,
quant_config: QuantizationConfig,
w13_qweight: torch.Tensor,
w2_qweight: torch.Tensor,
w13_scales: torch.Tensor,
w2_scales: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> tuple[
torch.Tensor, # w13_qweight
torch.Tensor, # w2_qweight
torch.Tensor, # w13_scales
torch.Tensor, # w2_scales
torch.Tensor | None, # w13_bias
torch.Tensor | None, # w2_bias
]:
"""Repack GPTQ-format INT4 MoE weights into the layout
`vllm_xpu_kernels.fused_moe_interface.xpu_fused_moe(is_int4=True)` expects:
w13: [E, 2*N, K] int4 (uint8 storage [E, 2*N, K // 2])
w13_scales: [E, 2*N, K // group_size] params_dtype
w2: [E, K, N] int4 (uint8 storage [E, K, N // 2])
w2_scales: [E, K, N // group_size] params_dtype
Input GPTQ layout from FusedMoE.weight_loader:
w13: [E, K // 8, 2*N] int32 (8 nibbles per int32 along the input dim)
w13_scales: [E, K // group_size, 2*N] params_dtype
w2: [E, N // 8, K] int32
w2_scales: [E, N // group_size, K] params_dtype
Transpose dim 1 ↔ dim 2 then view int32 → uint8 to recover sequential
int4-packed bytes along the input dim. Each packed int32 holds 8 nibbles
`(n7<<28)|(n6<<24)|...|(n1<<4)|n0` in ascending K order; on a
little-endian host the int32→uint8 view exposes them as bytes
`[n1<<4|n0, n3<<4|n2, n5<<4|n4, n7<<4|n6]`, i.e. two nibbles per byte
with the lower nibble = lower input-K index. xpu_fused_moe(is_int4=True)
expects this convention; on a big-endian host the byte order reverses
and the kernel would silently miscompute, so we hard-fail.
"""
del layer, quant_config # unused — kept for parity with the marlin helper
if sys.byteorder != "little":
raise NotImplementedError(
"_process_weights_xpu requires a little-endian host: the GPTQ "
"int32 → uint8 nibble repack relies on LE byte ordering."
)
w13_xpu = w13_qweight.transpose(1, 2).contiguous().view(torch.uint8)
w2_xpu = w2_qweight.transpose(1, 2).contiguous().view(torch.uint8)
w13_scales_xpu = w13_scales.transpose(1, 2).contiguous()
w2_scales_xpu = w2_scales.transpose(1, 2).contiguous()
return (
w13_xpu,
w2_xpu,
w13_scales_xpu,
w2_scales_xpu,
w13_bias,
w2_bias,
)
def convert_to_wna16_moe_kernel_format(
backend: WNA16MoEBackend,
layer: torch.nn.Module,
@@ -617,5 +706,40 @@ def convert_to_wna16_moe_kernel_format(
w13_bias,
w2_bias,
)
elif backend == WNA16MoEBackend.XPU:
(
w13_xpu,
w2_xpu,
w13_scale_xpu,
w2_scale_xpu,
w13_bias_out,
w2_bias_out,
) = _process_weights_xpu(
layer,
quant_config,
w13,
w2,
w13_scale,
w2_scale,
w13_bias,
w2_bias,
)
empty = torch.empty((0,), dtype=torch.int32, device=w13.device)
return (
w13_xpu,
w2_xpu,
w13_scale_xpu,
w2_scale_xpu,
empty, # w13_g_idx
empty, # w2_g_idx
empty, # w13_g_idx_sort_indices
empty, # w2_g_idx_sort_indices
None, # w13_qzeros — sym int4 on XPU has none; kernel does uint4b8→s4
None, # w2_qzeros
None, # w13_input_global_scale
None, # w2_input_global_scale
w13_bias_out,
w2_bias_out,
)
else:
raise ValueError(f"Unsupported wna16 MoE backend: {backend.value}")