Revert "[MoE Refactor] Migrate MoeWNA16Method quantization to MK orac… (#44033)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-05-29 19:45:29 -04:00
committed by GitHub
parent 8fad266507
commit 187457a952
8 changed files with 98 additions and 333 deletions
@@ -43,10 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kInt4Static,
kInt4Static32,
kInt8DynamicTokenSym,
kInt8Static,
kInt8StaticChannelSym,
)
from vllm.platforms import current_platform
@@ -459,45 +456,40 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular):
class TritonWNA16Experts(TritonExperts):
@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_cuda_alike() or current_platform.is_xpu()
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return True
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W = [
kInt4Static,
kInt8Static,
kInt4Static32,
# other group sizes?
]
return weight_key in SUPPORTED_W
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# Why?
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
def apply(
@@ -520,9 +512,7 @@ class TritonWNA16Experts(TritonExperts):
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1) // 2} == {w1.size(2)}"
)
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -1244,11 +1243,7 @@ def get_default_config(
bit = 4 if dtype == "int4_w4a16" else 8
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
if use_moe_wna16_cuda:
config = {
"BLOCK_SIZE_M": min(16, next_power_of_2(M)),
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
}
config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
elif M <= 20:
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
elif M <= 40:
@@ -23,9 +23,6 @@ from vllm.model_executor.layers.fused_moe.experts.marlin_moe import (
MarlinExperts,
MarlinExpertsBase,
)
from vllm.model_executor.layers.fused_moe.experts.triton_moe import (
TritonWNA16Experts,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_mxint4_moe import (
TrtLlmMxint4ExpertsMonolithic,
)
@@ -34,7 +31,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_act_int8_process_scales,
marlin_moe_permute_scales,
marlin_permute_bias,
marlin_zero_points,
moe_awq_to_marlin_zero_points,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -49,7 +45,6 @@ class WNA16MoEBackend(Enum):
MARLIN = "MARLIN"
BATCHED_MARLIN = "BATCHED_MARLIN"
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
TRITON = "TRITON"
XPU = "XPU"
@@ -63,8 +58,6 @@ def backend_to_kernel_cls(
return [BatchedMarlinExperts]
elif backend == WNA16MoEBackend.FLASHINFER_TRTLLM:
return [TrtLlmMxint4ExpertsMonolithic]
elif backend == WNA16MoEBackend.TRITON:
return [TritonWNA16Experts]
elif backend == WNA16MoEBackend.XPU:
from vllm.model_executor.layers.fused_moe.experts.xpu_moe import (
XPUExpertsWNA16,
@@ -75,38 +68,24 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown WNA16 MoE backend: {backend.value}")
def _get_priority_backends(
may_have_zp: bool, may_have_bias: bool
) -> list[WNA16MoEBackend]:
def _get_priority_backends() -> list[WNA16MoEBackend]:
"""
Get available backends in priority order based on platform and config.
"""
if current_platform.is_xpu():
return [WNA16MoEBackend.XPU]
_AVAILABLE_BACKENDS = []
if not may_have_zp and not may_have_bias:
_AVAILABLE_BACKENDS.append(WNA16MoEBackend.FLASHINFER_TRTLLM)
# Marlin supports ZP and bias
_AVAILABLE_BACKENDS += [
_AVAILABLE_BACKENDS = [
WNA16MoEBackend.FLASHINFER_TRTLLM,
WNA16MoEBackend.MARLIN,
WNA16MoEBackend.BATCHED_MARLIN,
]
return _AVAILABLE_BACKENDS
if not may_have_bias:
_AVAILABLE_BACKENDS.append(WNA16MoEBackend.TRITON)
return _AVAILABLE_BACKENDS
def select_wna16_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey,
may_have_zp: bool,
may_have_bias: bool,
) -> tuple[WNA16MoEBackend, type[mk.FusedMoEExperts]]:
"""Select the WNA16 MoE backend.
@@ -157,7 +136,7 @@ def select_wna16_moe_backend(
raise ValueError(_make_log_unsupported(backend, reason))
# Select kernels in order of backend.
AVAILABLE_BACKENDS = _get_priority_backends(may_have_zp, may_have_bias)
AVAILABLE_BACKENDS = _get_priority_backends()
for backend in AVAILABLE_BACKENDS:
activation_key = None # always BF16 activation for WNA16 MoE
@@ -239,7 +218,6 @@ def make_wna16_moe_kernel(
assert experts_cls in (
MarlinExperts,
BatchedMarlinExperts,
TritonWNA16Experts,
TrtLlmMxint4ExpertsMonolithic,
XPUExpertsWNA16,
)
@@ -256,7 +234,6 @@ def make_wna16_moe_kernel(
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")
logger.info_once("Using %s", experts_cls.__name__, scope="local")
extra_args: dict[str, Any] = {}
if issubclass(experts_cls, MarlinExpertsBase):
@@ -426,8 +403,6 @@ def _process_weights_marlin(
w2_input_global_scale: torch.Tensor | None = None
w13_bias_out: torch.Tensor | None = None
w2_bias_out: torch.Tensor | None = None
w13_qzeros_out: torch.Tensor | None = None
w2_qzeros_out: torch.Tensor | None = None
# --- FP8 weight / scale adjustment ---
if input_dtype == torch.float8_e4m3fn:
@@ -527,24 +502,6 @@ def _process_weights_marlin(
if w2_bias is not None:
w2_bias_out = marlin_permute_bias(w2_bias)
if w13_qzeros is not None:
w13_qzeros_out = marlin_zero_points(
w13_qzeros,
size_k=layer.intermediate_size_per_partition,
size_n=w13_qzeros.shape[2],
num_bits=num_bits,
is_a_8bit=is_a_8bit,
)
if w2_qzeros is not None:
w2_qzeros_out = marlin_zero_points(
w2_qzeros,
size_k=w2_qzeros.shape[1] * group_size_or_pack_factor,
size_n=w2_qzeros.shape[2],
num_bits=num_bits,
is_a_8bit=is_a_8bit,
)
return (
marlin_w13_qweight,
marlin_w2_qweight,
@@ -554,8 +511,8 @@ def _process_weights_marlin(
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
w13_qzeros_out,
w2_qzeros_out,
w13_qzeros,
w2_qzeros,
w13_input_global_scale,
w2_input_global_scale,
w13_bias_out,
@@ -823,9 +780,6 @@ def convert_to_wna16_moe_kernel_format(
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config,
)
if isinstance(quant_config, AWQMarlinConfig):
if w13_qzeros is None or w2_qzeros is None:
@@ -860,17 +814,11 @@ def convert_to_wna16_moe_kernel_format(
pack_factor = 32 // quant_config.num_bits
group_size = quant_config.group_size
actorder = quant_config.actorder
elif isinstance(quant_config, MoeWNA16Config):
num_bits = quant_config.weight_bits
pack_factor = quant_config.bit8_pack_factor
group_size = quant_config.group_size
actorder = None
else:
raise TypeError(
"Marlin WNA16 MoE backend requires AutoGPTQConfig, AWQMarlinConfig or "
f"QuantizationArgs, got {type(quant_config).__name__}."
)
if w13_g_idx is None or w2_g_idx is None:
raise ValueError("GPTQ Marlin MoE requires g_idx tensors.")
return _process_weights_marlin(
@@ -902,39 +850,6 @@ def convert_to_wna16_moe_kernel_format(
w13_bias,
w2_bias,
)
elif backend == WNA16MoEBackend.TRITON:
# Convert from int32 to uint8 format for Triton kernel.
# This changes the shape from (E, N, K // 8) to (E, N, K // 2) for int4,
# which matches what the Triton kernel expects.
w13_uint8 = w13.view(torch.uint8)
w2_uint8 = w2.view(torch.uint8)
return (
w13_uint8,
w2_uint8,
w13_scale,
w2_scale,
None,
None,
None,
None,
w13_qzeros,
w2_qzeros,
None,
None,
w13_bias,
w2_bias,
)
elif backend == WNA16MoEBackend.FLASHINFER_TRTLLM:
return _process_weights_flashinfer(
w13,
w2,
w13_scale,
w2_scale,
w13_g_idx,
w2_g_idx,
w13_bias,
w2_bias,
)
elif backend == WNA16MoEBackend.XPU:
assert quant_config is not None
(
@@ -485,8 +485,6 @@ class AutoGPTQMoEMethod(FusedMoEMethodBase):
self.wna16_moe_backend, self.experts_cls = select_wna16_moe_backend(
moe,
weight_key,
may_have_zp=True,
may_have_bias=True,
)
def create_weights(
@@ -646,22 +644,12 @@ class AutoGPTQMoEMethod(FusedMoEMethodBase):
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: RoutedExperts) -> None:
def replace_or_register(name: str, val: torch.Tensor | None):
if val is None:
return
if hasattr(layer, name):
replace_parameter(layer, name, val)
else:
layer.register_parameter(
name, torch.nn.Parameter(val, requires_grad=False)
)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
assert not is_a_8bit or self.quant_config.quant_type.size_bits == 8, (
"W8A8-INT8 is not supported by marlin kernel."
)
if is_a_8bit:
assert self.quant_config.quant_type.size_bits == 8, (
"W8A8-INT8 is not supported by marlin kernel."
)
(
w13,
@@ -672,8 +660,8 @@ class AutoGPTQMoEMethod(FusedMoEMethodBase):
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
w13_qzeros,
w2_qzeros,
_w13_qzeros,
_w2_qzeros,
w13_input_global_scale,
w2_input_global_scale,
w13_bias,
@@ -691,8 +679,6 @@ class AutoGPTQMoEMethod(FusedMoEMethodBase):
w2_g_idx=layer.w2_g_idx,
w13_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
w13_qzeros=getattr(layer, "w13_qzeros", None),
w2_qzeros=getattr(layer, "w2_qzeros", None),
)
replace_parameter(layer, "w13_qweight", w13)
@@ -703,12 +689,38 @@ class AutoGPTQMoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w2_g_idx", w2_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
replace_or_register("w13_input_global_scale", w13_input_global_scale)
replace_or_register("w2_input_global_scale", w2_input_global_scale)
replace_or_register("w13_bias", w13_bias)
replace_or_register("w2_bias", w2_bias)
replace_or_register("w13_qzeros", w13_qzeros)
replace_or_register("w2_qzeros", w2_qzeros)
if w13_input_global_scale is not None:
if hasattr(layer, "w13_input_global_scale"):
replace_parameter(
layer, "w13_input_global_scale", w13_input_global_scale
)
else:
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
if w2_input_global_scale is not None:
if hasattr(layer, "w2_input_global_scale"):
replace_parameter(layer, "w2_input_global_scale", w2_input_global_scale)
else:
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
if w13_bias is not None:
if hasattr(layer, "w13_bias"):
replace_parameter(layer, "w13_bias", w13_bias)
else:
layer.register_parameter(
"w13_bias", torch.nn.Parameter(w13_bias, requires_grad=False)
)
if w2_bias is not None:
if hasattr(layer, "w2_bias"):
replace_parameter(layer, "w2_bias", w2_bias)
else:
layer.register_parameter(
"w2_bias", torch.nn.Parameter(w2_bias, requires_grad=False)
)
self._setup_kernel(layer)
@@ -524,8 +524,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
self.wna16_moe_backend, self.experts_cls = select_wna16_moe_backend(
moe,
kInt4Static,
may_have_zp=self.quant_config.zero_point,
may_have_bias=True,
)
def create_weights(
@@ -88,13 +88,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self.wna16_backend, self.experts_cls = select_wna16_moe_backend(
config=self.moe,
weight_key=weight_key,
may_have_zp=False,
may_have_bias=False,
)
self.is_marlin = self.wna16_backend in [
WNA16MoEBackend.MARLIN,
WNA16MoEBackend.BATCHED_MARLIN,
]
def get_weight_shape(
self,
@@ -120,6 +114,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
"num_groups_w2 must be provided for weight scales"
)
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
is_flashinfer = self.wna16_backend == WNA16MoEBackend.FLASHINFER_TRTLLM
shape_map = {
"w13_weight": {
"Flashinfer": (
@@ -162,7 +157,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
"Marlin": (num_experts, num_groups_w2, hidden_size),
},
}
backend_key = "Marlin" if self.is_marlin else "Flashinfer"
backend_key = "Flashinfer" if is_flashinfer else "Marlin"
return shape_map[weight_name][backend_key]
def create_weights(
@@ -179,8 +174,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
is_transposed = self.wna16_backend != WNA16MoEBackend.FLASHINFER_TRTLLM
extra_weight_attrs.update(
{"is_transposed": self.is_marlin, "quant_method": self.strategy}
{"is_transposed": is_transposed, "quant_method": self.strategy}
)
w13_weight = torch.nn.Parameter(
@@ -328,6 +324,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process weights using the shared oracle infrastructure
is_flashinfer = self.wna16_backend == WNA16MoEBackend.FLASHINFER_TRTLLM
(
w13_qweight,
w2_qweight,
@@ -363,7 +360,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
replace_parameter(layer, "w2_weight_scale", w2_scales)
# Marlin-specific parameters (not needed for Flashinfer)
if self.is_marlin:
if not is_flashinfer:
replace_parameter(layer, "w13_weight_g_idx", w13_g_idx_processed)
replace_parameter(layer, "w2_weight_g_idx", w2_g_idx_processed)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
@@ -395,7 +392,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
# Add Marlin-specific arguments
marlin_args: dict[str, Any] = {}
if self.is_marlin:
if not is_flashinfer:
marlin_args = {
"w13_g_idx": layer.w13_weight_g_idx,
"w2_g_idx": layer.w2_weight_g_idx,
@@ -13,15 +13,11 @@ from vllm.model_executor.layers.fused_moe import (
RoutedExperts,
SharedExperts,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.int_wna16 import (
WNA16MoEBackend,
convert_to_wna16_moe_kernel_format,
make_wna16_moe_kernel,
make_wna16_moe_quant_config,
select_wna16_moe_backend,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
@@ -35,15 +31,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
INT4_DTYPE,
INT8_DTYPE,
QuantKey,
kInt4Static32GroupScale,
kInt4StaticGroupScale,
kInt8StaticGroupScale,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -228,33 +216,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
super().__init__(moe)
self.quant_config = quant_config
num_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
if num_bits == 4:
quant_type = INT4_DTYPE
if group_size == 32:
scale = kInt4Static32GroupScale
else:
scale = kInt4StaticGroupScale
elif num_bits == 8:
assert group_size == -1
quant_type = INT8_DTYPE
scale = kInt8StaticGroupScale
else:
raise ValueError("MoeWNA16Method only supports int4 and int8 now.")
weight_key = QuantKey(quant_type, scale)
# Select WNA16 MoE backend via oracle.
# handle ZP?
self.wna16_backend, self.experts_cls = select_wna16_moe_backend(
config=self.moe,
weight_key=weight_key,
may_have_zp=self.quant_config.has_zp,
may_have_bias=False,
)
def create_weights(
self,
layer: RoutedExperts,
@@ -375,89 +336,24 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
has_zp = self.quant_config.has_zp
(
w13_qweight,
w2_qweight,
w13_scales,
w2_scales,
w13_g_idx_processed,
w2_g_idx_processed,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
w13_qzeros,
w2_qzeros,
w13_input_global_scale,
w2_input_global_scale,
_, # w13_bias
_, # w2_bias
) = convert_to_wna16_moe_kernel_format(
backend=self.wna16_backend,
layer=layer,
quant_config=self.quant_config,
input_dtype=None,
w13=layer.w13_qweight,
w2=layer.w2_qweight,
w13_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w13_g_idx=getattr(layer, "w13_g_idx", None),
w2_g_idx=getattr(layer, "w2_g_idx", None),
w13_qzeros=layer.w13_qzeros if has_zp else None,
w2_qzeros=layer.w2_qzeros if has_zp else None,
)
# Replace common parameters
replace_parameter(layer, "w13_qweight", w13_qweight)
replace_parameter(layer, "w2_qweight", w2_qweight)
replace_parameter(layer, "w13_scales", w13_scales)
replace_parameter(layer, "w2_scales", w2_scales)
if has_zp:
assert w13_qzeros is not None and w2_qzeros is not None
replace_parameter(layer, "w13_qzeros", w13_qzeros)
replace_parameter(layer, "w2_qzeros", w2_qzeros)
# Marlin-specific parameters (not needed for Flashinfer)
if self.wna16_backend != WNA16MoEBackend.FLASHINFER_TRTLLM:
replace_parameter(layer, "w13_g_idx", w13_g_idx_processed)
replace_parameter(layer, "w2_g_idx", w2_g_idx_processed)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
# Register input global scales if present
if w13_input_global_scale is not None:
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
if w2_input_global_scale is not None:
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
assert self.experts_cls is not None
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.moe_kernel = make_wna16_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
routing_tables=layer._expert_routing_tables(),
)
def get_fused_moe_quant_config(
self, layer: RoutedExperts
) -> FusedMoEQuantConfig | None:
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return make_wna16_moe_quant_config(
assert weight_bits == 4 or weight_bits == 8
config_builder = (
int4_w4a16_moe_quant_config
if weight_bits == 4
else int8_w8a16_moe_quant_config
)
return config_builder(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
group_size=layer.group_size,
num_bits=self.quant_config.weight_bits,
block_shape=[0, layer.group_size],
)
def apply(
@@ -469,44 +365,22 @@ class MoeWNA16Method(FusedMoEMethodBase):
shared_experts: SharedExperts | None,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights,
topk_ids,
activation=layer.activation,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts=shared_experts,
shared_experts_input=shared_experts_input,
)
def apply_monolithic(
self,
layer: RoutedExperts,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
quant_config=self.moe_quant_config,
)
@staticmethod
@@ -3,7 +3,7 @@
from collections.abc import Mapping
from copy import deepcopy
from types import MappingProxyType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import regex as re
import torch
@@ -69,20 +69,6 @@ def get_dynamic_override(
return default_value
def flatten_list(lst: list[Any]) -> list[Any]:
    """Return a flat list of the non-list items found in *lst*.

    Nested ``list`` objects are expanded depth-first, left to right, so the
    relative order of the leaf items is preserved. Only ``list`` instances
    are expanded; tuples, strings and other containers are kept as single
    items. The input is not modified.

    Args:
        lst: A list whose items may themselves be (arbitrarily nested) lists.

    Returns:
        A new list holding every non-list item, in traversal order.

    Note:
        Traversal uses an explicit stack rather than recursion, so the
        nesting depth is not bounded by the interpreter's recursion limit.
        A list that (directly or indirectly) contains itself is not
        supported and will never finish.
    """
    output: list[Any] = []
    # Each stack entry is a partially consumed iterator over one nesting
    # level; the top of the stack is the level currently being walked.
    pending = [iter(lst)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, list):
                # Descend: walk the child first, then resume this level
                # from where its iterator stopped.
                pending.append(iter(item))
                break
            output.append(item)
        else:
            # Current level is exhausted; go back to its parent.
            pending.pop()
    return output
def is_layer_gptq_quantized(
prefix: str,
quantized_layers: list[str],
@@ -97,8 +83,6 @@ def is_layer_gptq_quantized(
proj_name = prefix.split(".")[-1]
quantized_layers = flatten_list(quantized_layers)
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that