From 7b98f498cdf0bf9cf0ecc37b5c0c994cb94513c3 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Fri, 29 May 2026 17:26:56 -0400 Subject: [PATCH] [MoE Refactor] Remove supports_expert_map (#43108) Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/common.py | 15 +++++++++----- .../moe/modular_kernel_tools/mk_objects.py | 15 -------------- .../moe/test_modular_kernel_combinations.py | 2 +- .../fused_moe/experts/aiter_mxfp4_w4a8_moe.py | 3 --- .../experts/batched_deep_gemm_moe.py | 3 --- .../layers/fused_moe/experts/cpu_moe.py | 6 ------ .../layers/fused_moe/experts/cutlass_moe.py | 18 ++--------------- .../layers/fused_moe/experts/deep_gemm_moe.py | 6 ------ .../layers/fused_moe/experts/fallback.py | 10 ---------- .../experts/flashinfer_cutedsl_batched_moe.py | 3 --- .../experts/flashinfer_cutedsl_moe.py | 3 --- .../experts/flashinfer_cutlass_moe.py | 3 --- .../fused_moe/experts/fused_batched_moe.py | 6 ------ .../fused_moe/experts/fused_humming_moe.py | 3 --- .../experts/gpt_oss_triton_kernels_moe.py | 6 ------ .../layers/fused_moe/experts/marlin_moe.py | 6 ------ .../fused_moe/experts/rocm_aiter_moe.py | 3 --- .../layers/fused_moe/experts/triton_moe.py | 3 --- .../fused_moe/experts/trtllm_bf16_moe.py | 3 --- .../fused_moe/experts/trtllm_fp8_moe.py | 3 --- .../fused_moe/experts/trtllm_mxfp4_moe.py | 6 ------ .../fused_moe/experts/trtllm_nvfp4_moe.py | 3 --- .../layers/fused_moe/experts/xpu_moe.py | 3 --- .../fused_moe/fused_moe_modular_method.py | 7 +------ .../layers/fused_moe/modular_kernel.py | 13 ------------ .../fused_moe/prepare_finalize/naive_dp_ep.py | 20 +++++++++++-------- .../compressed_tensors_moe_w8a8_fp8.py | 2 -- 27 files changed, 26 insertions(+), 148 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index ea52a2d3398..fdd00cfa27a 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -224,10 +224,6 @@ class Config: info = expert_info(self.fused_experts_type) return info.blocked_quantization_support - def supports_expert_map(self): - info = expert_info(self.fused_experts_type) - return info.supports_expert_map - def supports_apply_weight_on_input(self): info = prepare_finalize_info(self.prepare_finalize_type) return info.supports_apply_weight_on_input @@ -326,6 +322,15 @@ class Config: if self.needs_mori() and not has_mori(): # noqa: SIM103 return False, "Needs MoRI, but MoRI not available." + try: + if not self.fused_experts_type._supports_current_device(): + return ( + False, + f"{self.fused_experts_type} not supported on the current device.", + ) + except NotImplementedError: + pass + return True, None @@ -471,7 +476,7 @@ class RankTensors: topk_ids = topk_ids.to(device=device) expert_map = None - if config.world_size > 1 and config.supports_expert_map(): + if config.world_size > 1: expert_map = torch.full( (global_num_experts,), fill_value=-1, dtype=torch.int32 ) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 7c3bde2eafa..78ee8084d90 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -67,7 +67,6 @@ class ExpertInfo: activation_format: mk.FusedMoEActivationFormat supported_dtypes: list[torch.dtype | str] blocked_quantization_support: bool - supports_expert_map: bool needs_matching_quant: bool = False needs_deep_gemm: bool = False needs_aiter: bool = False @@ -129,7 +128,6 @@ def register_experts( activation_format: mk.FusedMoEActivationFormat, supported_dtypes: list[torch.dtype | str], blocked_quantization_support: bool, - supports_expert_map: bool, needs_matching_quant: bool = False, needs_deep_gemm: bool = False, needs_aiter: bool = False, @@ -142,7 +140,6 @@ def register_experts( activation_format, supported_dtypes, blocked_quantization_support, - supports_expert_map, needs_matching_quant, needs_deep_gemm, needs_aiter, @@ -176,7 +173,6 @@ register_experts( batched_format, common_float_types, blocked_quantization_support=True, - supports_expert_map=False, needs_matching_quant=True, ) @@ -185,7 +181,6 @@ register_experts( standard_format, common_float_and_int_types, blocked_quantization_support=True, - supports_expert_map=True, needs_matching_quant=True, ) @@ -194,7 +189,6 @@ register_experts( batched_format, common_float_and_int_types, blocked_quantization_support=True, - supports_expert_map=True, ) # Disable on blackwell for now @@ -260,7 +254,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability nvfp4_types + fp8_types, blocked_quantization_support=True, # Note: this is a hack to get it to run for now - supports_expert_map=True, ) else: FlashInferCutlassMoEPrepareAndFinalize = None @@ -294,7 +287,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability standard_format, nvfp4_types, blocked_quantization_support=False, - supports_expert_map=True, ) if has_aiter(): @@ -307,7 +299,6 @@ if has_aiter(): standard_format, fp8_types, blocked_quantization_support=True, - supports_expert_map=True, needs_aiter=True, ) else: @@ -319,7 +310,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): batched_format, fp8_types, blocked_quantization_support=True, - supports_expert_map=False, needs_matching_quant=False, needs_deep_gemm=True, ) @@ -328,7 +318,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): standard_format, fp8_types, blocked_quantization_support=True, - supports_expert_map=True, needs_matching_quant=False, needs_deep_gemm=True, ) @@ -337,7 +326,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): standard_format, common_float_and_int_types, blocked_quantization_support=True, - supports_expert_map=True, needs_matching_quant=True, needs_deep_gemm=True, ) @@ -353,14 +341,12 @@ if cutlass_fp8_supported(): standard_format, fp8_types, blocked_quantization_support=False, - supports_expert_map=False, ) register_experts( CutlassBatchedExpertsFp8, batched_format, fp8_types, blocked_quantization_support=False, - supports_expert_map=False, ) else: CutlassBatchedExpertsFp8 = None @@ -376,7 +362,6 @@ if cutlass_fp4_supported(): standard_format, nvfp4_types, blocked_quantization_support=True, - supports_expert_map=False, ) else: CutlassExpertsFp4 = None diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index c7295f3ed6e..0c0e1d61f90 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -227,7 +227,7 @@ def is_nyi_config(config: Config) -> bool: ) == 1 return unsupported_quant_config - return not info.supports_expert_map + return False def generate_valid_test_cases( diff --git a/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py b/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py index 3906a7e057c..cc2adc31fcd 100644 --- a/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py @@ -248,9 +248,6 @@ class AiterW4A8ExpertsMonolithic(mk.FusedMoEExpertsMonolithic): ) -> bool: return True - def supports_expert_map(self) -> bool: - return False # Expert parallelism not yet supported - @property def expects_unquantized_inputs(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py index 7bd383b9cda..c8611217a18 100644 --- a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py @@ -316,9 +316,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular): def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True - def supports_expert_map(self) -> bool: - return False - def supports_packed_ue8m0_act_scales(self) -> bool: """ DeepGemm supports packed ue8m0 activation scales format in devices == sm100 diff --git a/vllm/model_executor/layers/fused_moe/experts/cpu_moe.py b/vllm/model_executor/layers/fused_moe/experts/cpu_moe.py index 54b264ef772..84740fc0570 100644 --- a/vllm/model_executor/layers/fused_moe/experts/cpu_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/cpu_moe.py @@ -100,9 +100,6 @@ class CPUExpertsFp8(mk.FusedMoEExpertsMonolithic): ) -> bool: return True - def supports_expert_map(self) -> bool: - return False - def apply( self, hidden_states: torch.Tensor, @@ -256,9 +253,6 @@ class CPUExpertsMxfp4(mk.FusedMoEExpertsMonolithic): ) -> bool: return True - def supports_expert_map(self) -> bool: - return False - def apply( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py index 28a7d283b4b..feb49d260e1 100644 --- a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py @@ -378,7 +378,8 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): topk_ids, activation, global_num_experts, - expert_map, + # the fp8 cutlass experts use their own expert map. + None, self.w1_scale, self.w2_scale, a1q_scale, @@ -418,9 +419,6 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): or moe_parallel_config.use_fi_nvl_one_sided_kernels ) - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # topk weights and reduction are fused in moe_unpermute cuda kernel return TopKWeightAndReduceNoOP() @@ -460,9 +458,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts - def supports_expert_map(self) -> bool: - return False - def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: return self.out_dtype if self.out_dtype is not None else act_dtype @@ -741,9 +736,6 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular): def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() @@ -1038,9 +1030,6 @@ class CutlassExpertsMxfp4(mk.FusedMoEExpertsModular): def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() @@ -1340,9 +1329,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # topk weights and reduction are fused in moe_unpermute cuda kernel return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py index e3e15e31618..3b354dd3ef1 100644 --- a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py @@ -164,9 +164,6 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular): or moe_parallel_config.use_fi_nvl_one_sided_kernels ) - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() @@ -388,9 +385,6 @@ class DeepGemmFP4Experts(mk.FusedMoEExpertsModular): or moe_parallel_config.use_fi_nvl_one_sided_kernels ) - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/fallback.py b/vllm/model_executor/layers/fused_moe/experts/fallback.py index 40741d52af5..639b2bf2668 100644 --- a/vllm/model_executor/layers/fused_moe/experts/fallback.py +++ b/vllm/model_executor/layers/fused_moe/experts/fallback.py @@ -92,16 +92,6 @@ class FallbackExperts(mk.FusedMoEExpertsModular, ABC): moe_parallel_config ) and fallback_cls._supports_parallel_config(moe_parallel_config) - def supports_expert_map(self) -> bool: - assert ( - self.experts.supports_expert_map() - == self.fallback_experts.supports_expert_map() - ) - return ( - self.experts.supports_expert_map() - and self.fallback_experts.supports_expert_map() - ) - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: e_war = self.experts.finalize_weight_and_reduce_impl() fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl() diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py index 5eaaf46739f..253d1dae711 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py @@ -89,9 +89,6 @@ class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular): def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index 2310982792f..b512d51c135 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -98,9 +98,6 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): ) -> bool: return True - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py index b891583e3ef..fd9446c2a22 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py @@ -207,9 +207,6 @@ class FlashInferExperts(mk.FusedMoEExpertsModular): def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/fused_batched_moe.py index 0e31331e726..1f5724ac39c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/fused_batched_moe.py @@ -555,9 +555,6 @@ class NaiveBatchedExperts(mk.FusedMoEExpertsModular): "This method should not be called." ) - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() @@ -799,9 +796,6 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular): def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True - def supports_expert_map(self) -> bool: - return False - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() diff --git a/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py b/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py index 8874228a142..53623f13254 100644 --- a/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py @@ -156,9 +156,6 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular): ) -> bool: return True - def supports_expert_map(self) -> bool: - return True - @staticmethod def _supports_current_device() -> bool: platform = current_platform diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index 98265abf7c8..03bf925fbd9 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -608,9 +608,6 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular): def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True - def supports_expert_map(self) -> bool: - return True - def moe_problem_size( self, a1: torch.Tensor, @@ -1036,9 +1033,6 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic): ) -> bool: return True - def supports_expert_map(self) -> bool: - return True - @property def expects_unquantized_inputs(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/experts/marlin_moe.py b/vllm/model_executor/layers/fused_moe/experts/marlin_moe.py index 8bb9e5bdc06..1d0cf91d427 100644 --- a/vllm/model_executor/layers/fused_moe/experts/marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/marlin_moe.py @@ -686,9 +686,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase): """Marlin-based fused MoE expert implementation.""" - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() @@ -920,9 +917,6 @@ class BatchedMarlinExperts(MarlinExpertsBase): is_k_full=is_k_full, ) - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceDelegate() diff --git a/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py b/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py index b272a458b17..8415ac02784 100644 --- a/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py @@ -441,9 +441,6 @@ class AiterExperts(mk.FusedMoEExpertsModular): or moe_parallel_config.use_fi_nvl_one_sided_kernels ) - def supports_expert_map(self): - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py index ddcef519da0..cf2c2cff6a8 100644 --- a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py @@ -128,9 +128,6 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): def _supports_batch_invariance(): return True - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py index 02b7450a5c9..592a1513d75 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py @@ -99,9 +99,6 @@ class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic): ) -> bool: return True - def supports_expert_map(self) -> bool: - return False - @property def expects_unquantized_inputs(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index b98b84cdc62..43126195205 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -88,9 +88,6 @@ class TrtLlmFp8ExpertsBase: or moe_parallel_config.use_ag_rs_all2all_kernels ) and not moe_parallel_config.enable_eplb - def supports_expert_map(self) -> bool: - return False - class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): """ diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index 1e2fff8eb66..43f800343c7 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -113,9 +113,6 @@ class TrtLlmMxfp4ExpertsBase: def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def supports_expert_map(self) -> bool: - return False - @property def expects_unquantized_inputs(self) -> bool: return False @@ -248,9 +245,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula # routing is done externally, so accept any routing method. return True - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index e4f292b7b1e..5ee023aa27c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -179,9 +179,6 @@ class TrtLlmNvFp4ExpertsBase: 300000, _calc_max_supported_tokens(self.topk, self.moe_config.num_experts) ) - def supports_expert_map(self) -> bool: - return False - class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular): """ diff --git a/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py b/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py index 8cbf0a6ce02..82969dd8e25 100644 --- a/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py @@ -107,9 +107,6 @@ class XPUExperts(mk.FusedMoEExpertsModular): ] return (weight_key, activation_key) in SUPPORTED_W_A - def supports_expert_map(self) -> bool: - return True - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 3ebb63b0057..dd21ff58fc3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -34,11 +34,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): super().__init__(moe_kernel.moe_config) self.moe_quant_config = old_quant_method.moe_quant_config self.moe_kernel = moe_kernel - self.disable_expert_map = getattr( - old_quant_method, - "disable_expert_map", - not self.moe_kernel.supports_expert_map(), - ) self.old_quant_method = old_quant_method logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) @@ -103,7 +98,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): activation=layer.activation, global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, - expert_map=None if self.disable_expert_map else layer.expert_map, + expert_map=layer.expert_map, shared_experts=shared_experts, shared_experts_input=shared_experts_input, ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6fbc1bffaac..9c3ecee9f9b 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -751,13 +751,6 @@ class FusedMoEExperts(ABC): """ return False - @abstractmethod - def supports_expert_map(self) -> bool: - """ - A flag indicating whether or not this class supports expert maps - """ - raise NotImplementedError - def supports_packed_ue8m0_act_scales(self) -> bool: """ A flag indicating whether or not this class can process packed ue8m0 @@ -1567,12 +1560,6 @@ class FusedMoEKernel: == self.fused_experts.activation_format() ) - def supports_expert_map(self) -> bool: - """ - A flag indicating whether or not this class supports expert maps. - """ - return self.fused_experts.supports_expert_map() - def output_is_reduced(self) -> bool: """ Indicates whether or not the output of fused MoE kernel diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py index ffbb4c4a7d3..89f3843cc50 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -17,7 +17,7 @@ def _quantize_and_setup_dispatch( a1: torch.Tensor, quant_config: FusedMoEQuantConfig, defer_input_quant: bool = False, -) -> tuple[torch.Tensor, list[torch.Tensor] | None]: +) -> tuple[torch.Tensor, list[torch.Tensor] | None, torch.Tensor | None]: # Defer input quantization to the MoE kernel. if defer_input_quant: a1q = a1 @@ -33,7 +33,7 @@ def _quantize_and_setup_dispatch( # which makes the scales tensor different shape than # the hidden states, breaking the A2A kernel. So, we # delay the swizzling until after the A2A. - a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input( + a1q, a1q_scale = moe_kernel_quantize_input( a1, input_sf, quant_dtype=quant_config.quant_dtype, @@ -49,7 +49,7 @@ def _quantize_and_setup_dispatch( skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 scales = None if skip_gather_scales else [a1q_scale] - return a1q, scales + return a1q, scales, a1q_scale def _unwrap_scale_and_prepare_for_moe( @@ -129,7 +129,9 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular ) a1 = a1 * topk_weights.to(a1.dtype) - a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + a1q, scales, a1q_scale_orig = _quantize_and_setup_dispatch( + a1, quant_config, defer_input_quant + ) # When LoRA is active, dispatch the per-token LoRA id along with # hidden_states so every rank receives the correct mapping for the @@ -164,7 +166,7 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular if extra_tensors is None: assert len(res) == 3 a1q, topk_weights, topk_ids = res - a1q_scale = None + a1q_scale = a1q_scale_orig else: assert len(res) == 4 a1q, topk_weights, topk_ids, gathered_extras = res @@ -178,7 +180,7 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular gathered_extras, quant_config ) else: - a1q_scale = None + a1q_scale = a1q_scale_orig return a1q, a1q_scale, None, topk_ids, topk_weights @@ -249,7 +251,9 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono ) -> mk.PrepareMonolithicResultType: """Quantize and Dispatch Router Logits.""" - a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + a1q, scales, a1q_scale_orig = _quantize_and_setup_dispatch( + a1, quant_config, defer_input_quant + ) res = get_ep_group().dispatch_router_logits( a1q, @@ -261,7 +265,7 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono if scales is None: assert len(res) == 2 a1q, router_logits = res - a1q_scale = None + a1q_scale = a1q_scale_orig else: assert len(res) == 3 a1q, router_logits, scales = res diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py index da5d85e4abc..14ef8bf614c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -405,8 +405,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, - # TODO(rob): investigate the disable_expert_map introduced by: - # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, shared_experts=shared_experts,