[MoE Refactor] Remove supports_expert_map (#43108)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-05-29 17:26:56 -04:00
committed by GitHub
parent 106aa92f04
commit 7b98f498cd
27 changed files with 26 additions and 148 deletions
@@ -224,10 +224,6 @@ class Config:
info = expert_info(self.fused_experts_type)
return info.blocked_quantization_support
def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
def supports_apply_weight_on_input(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supports_apply_weight_on_input
@@ -326,6 +322,15 @@ class Config:
if self.needs_mori() and not has_mori(): # noqa: SIM103
return False, "Needs MoRI, but MoRI not available."
try:
if not self.fused_experts_type._supports_current_device():
return (
False,
f"{self.fused_experts_type} not supported on the current device.",
)
except NotImplementedError:
pass
return True, None
@@ -471,7 +476,7 @@ class RankTensors:
topk_ids = topk_ids.to(device=device)
expert_map = None
if config.world_size > 1 and config.supports_expert_map():
if config.world_size > 1:
expert_map = torch.full(
(global_num_experts,), fill_value=-1, dtype=torch.int32
)
@@ -67,7 +67,6 @@ class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[torch.dtype | str]
blocked_quantization_support: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
needs_aiter: bool = False
@@ -129,7 +128,6 @@ def register_experts(
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[torch.dtype | str],
blocked_quantization_support: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
needs_aiter: bool = False,
@@ -142,7 +140,6 @@ def register_experts(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
needs_aiter,
@@ -176,7 +173,6 @@ register_experts(
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_expert_map=False,
needs_matching_quant=True,
)
@@ -185,7 +181,6 @@ register_experts(
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_expert_map=True,
needs_matching_quant=True,
)
@@ -194,7 +189,6 @@ register_experts(
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_expert_map=True,
)
# Disable on blackwell for now
@@ -260,7 +254,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
nvfp4_types + fp8_types,
blocked_quantization_support=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
else:
FlashInferCutlassMoEPrepareAndFinalize = None
@@ -294,7 +287,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
standard_format,
nvfp4_types,
blocked_quantization_support=False,
supports_expert_map=True,
)
if has_aiter():
@@ -307,7 +299,6 @@ if has_aiter():
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_expert_map=True,
needs_aiter=True,
)
else:
@@ -319,7 +310,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
batched_format,
fp8_types,
blocked_quantization_support=True,
supports_expert_map=False,
needs_matching_quant=False,
needs_deep_gemm=True,
)
@@ -328,7 +318,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
)
@@ -337,7 +326,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
)
@@ -353,14 +341,12 @@ if cutlass_fp8_supported():
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_expert_map=False,
)
register_experts(
CutlassBatchedExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_expert_map=False,
)
else:
CutlassBatchedExpertsFp8 = None
@@ -376,7 +362,6 @@ if cutlass_fp4_supported():
standard_format,
nvfp4_types,
blocked_quantization_support=True,
supports_expert_map=False,
)
else:
CutlassExpertsFp4 = None
@@ -227,7 +227,7 @@ def is_nyi_config(config: Config) -> bool:
) == 1
return unsupported_quant_config
return not info.supports_expert_map
return False
def generate_valid_test_cases(
@@ -248,9 +248,6 @@ class AiterW4A8ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False # Expert parallelism not yet supported
@property
def expects_unquantized_inputs(self) -> bool:
return True
@@ -316,9 +316,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def supports_packed_ue8m0_act_scales(self) -> bool:
"""
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
@@ -100,9 +100,6 @@ class CPUExpertsFp8(mk.FusedMoEExpertsMonolithic):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def apply(
self,
hidden_states: torch.Tensor,
@@ -256,9 +253,6 @@ class CPUExpertsMxfp4(mk.FusedMoEExpertsMonolithic):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def apply(
self,
hidden_states: torch.Tensor,
@@ -378,7 +378,8 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
topk_ids,
activation,
global_num_experts,
expert_map,
# the fp8 cutlass experts use their own expert map.
None,
self.w1_scale,
self.w2_scale,
a1q_scale,
@@ -418,9 +419,6 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
@@ -460,9 +458,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def supports_expert_map(self) -> bool:
return False
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
@@ -741,9 +736,6 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -1038,9 +1030,6 @@ class CutlassExpertsMxfp4(mk.FusedMoEExpertsModular):
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -1340,9 +1329,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
@@ -164,9 +164,6 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -388,9 +385,6 @@ class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -92,16 +92,6 @@ class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
moe_parallel_config
) and fallback_cls._supports_parallel_config(moe_parallel_config)
def supports_expert_map(self) -> bool:
assert (
self.experts.supports_expert_map()
== self.fallback_experts.supports_expert_map()
)
return (
self.experts.supports_expert_map()
and self.fallback_experts.supports_expert_map()
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
e_war = self.experts.finalize_weight_and_reduce_impl()
fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl()
@@ -89,9 +89,6 @@ class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
@@ -98,9 +98,6 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -207,9 +207,6 @@ class FlashInferExperts(mk.FusedMoEExpertsModular):
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -555,9 +555,6 @@ class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
"This method should not be called."
)
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
@@ -799,9 +796,6 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
@@ -156,9 +156,6 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
@staticmethod
def _supports_current_device() -> bool:
platform = current_platform
@@ -608,9 +608,6 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def moe_problem_size(
self,
a1: torch.Tensor,
@@ -1036,9 +1033,6 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
@property
def expects_unquantized_inputs(self) -> bool:
return True
@@ -686,9 +686,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
"""Marlin-based fused MoE expert implementation."""
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -920,9 +917,6 @@ class BatchedMarlinExperts(MarlinExpertsBase):
is_k_full=is_k_full,
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceDelegate()
@@ -441,9 +441,6 @@ class AiterExperts(mk.FusedMoEExpertsModular):
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -128,9 +128,6 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular):
def _supports_batch_invariance():
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -99,9 +99,6 @@ class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
@property
def expects_unquantized_inputs(self) -> bool:
return True
@@ -88,9 +88,6 @@ class TrtLlmFp8ExpertsBase:
or moe_parallel_config.use_ag_rs_all2all_kernels
) and not moe_parallel_config.enable_eplb
def supports_expert_map(self) -> bool:
return False
class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
"""
@@ -113,9 +113,6 @@ class TrtLlmMxfp4ExpertsBase:
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
@property
def expects_unquantized_inputs(self) -> bool:
return False
@@ -248,9 +245,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
# routing is done externally, so accept any routing method.
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -179,9 +179,6 @@ class TrtLlmNvFp4ExpertsBase:
300000, _calc_max_supported_tokens(self.topk, self.moe_config.num_experts)
)
def supports_expert_map(self) -> bool:
return False
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
"""
@@ -107,9 +107,6 @@ class XPUExperts(mk.FusedMoEExpertsModular):
]
return (weight_key, activation_key) in SUPPORTED_W_A
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@@ -34,11 +34,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
super().__init__(moe_kernel.moe_config)
self.moe_quant_config = old_quant_method.moe_quant_config
self.moe_kernel = moe_kernel
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
not self.moe_kernel.supports_expert_map(),
)
self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@@ -103,7 +98,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
expert_map=layer.expert_map,
shared_experts=shared_experts,
shared_experts_input=shared_experts_input,
)
@@ -751,13 +751,6 @@ class FusedMoEExperts(ABC):
"""
return False
@abstractmethod
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps
"""
raise NotImplementedError
def supports_packed_ue8m0_act_scales(self) -> bool:
"""
A flag indicating whether or not this class can process packed ue8m0
@@ -1567,12 +1560,6 @@ class FusedMoEKernel:
== self.fused_experts.activation_format()
)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
"""
return self.fused_experts.supports_expert_map()
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of fused MoE kernel
@@ -17,7 +17,7 @@ def _quantize_and_setup_dispatch(
a1: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
) -> tuple[torch.Tensor, list[torch.Tensor] | None, torch.Tensor | None]:
# Defer input quantization to the MoE kernel.
if defer_input_quant:
a1q = a1
@@ -33,7 +33,7 @@ def _quantize_and_setup_dispatch(
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input(
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_dtype=quant_config.quant_dtype,
@@ -49,7 +49,7 @@ def _quantize_and_setup_dispatch(
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
return a1q, scales
return a1q, scales, a1q_scale
def _unwrap_scale_and_prepare_for_moe(
@@ -129,7 +129,9 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
)
a1 = a1 * topk_weights.to(a1.dtype)
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
a1q, scales, a1q_scale_orig = _quantize_and_setup_dispatch(
a1, quant_config, defer_input_quant
)
# When LoRA is active, dispatch the per-token LoRA id along with
# hidden_states so every rank receives the correct mapping for the
@@ -164,7 +166,7 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
if extra_tensors is None:
assert len(res) == 3
a1q, topk_weights, topk_ids = res
a1q_scale = None
a1q_scale = a1q_scale_orig
else:
assert len(res) == 4
a1q, topk_weights, topk_ids, gathered_extras = res
@@ -178,7 +180,7 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
gathered_extras, quant_config
)
else:
a1q_scale = None
a1q_scale = a1q_scale_orig
return a1q, a1q_scale, None, topk_ids, topk_weights
@@ -249,7 +251,9 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
) -> mk.PrepareMonolithicResultType:
"""Quantize and Dispatch Router Logits."""
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
a1q, scales, a1q_scale_orig = _quantize_and_setup_dispatch(
a1, quant_config, defer_input_quant
)
res = get_ep_group().dispatch_router_logits(
a1q,
@@ -261,7 +265,7 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
if scales is None:
assert len(res) == 2
a1q, router_logits = res
a1q_scale = None
a1q_scale = a1q_scale_orig
else:
assert len(res) == 3
a1q, router_logits, scales = res
@@ -405,8 +405,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts=shared_experts,