mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[MoE Refactor] Remove supports_expert_map (#43108)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -224,10 +224,6 @@ class Config:
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.blocked_quantization_support
|
||||
|
||||
def supports_expert_map(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_expert_map
|
||||
|
||||
def supports_apply_weight_on_input(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.supports_apply_weight_on_input
|
||||
@@ -326,6 +322,15 @@ class Config:
|
||||
if self.needs_mori() and not has_mori(): # noqa: SIM103
|
||||
return False, "Needs MoRI, but MoRI not available."
|
||||
|
||||
try:
|
||||
if not self.fused_experts_type._supports_current_device():
|
||||
return (
|
||||
False,
|
||||
f"{self.fused_experts_type} not supported on the current device.",
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
@@ -471,7 +476,7 @@ class RankTensors:
|
||||
topk_ids = topk_ids.to(device=device)
|
||||
|
||||
expert_map = None
|
||||
if config.world_size > 1 and config.supports_expert_map():
|
||||
if config.world_size > 1:
|
||||
expert_map = torch.full(
|
||||
(global_num_experts,), fill_value=-1, dtype=torch.int32
|
||||
)
|
||||
|
||||
@@ -67,7 +67,6 @@ class ExpertInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
supports_expert_map: bool
|
||||
needs_matching_quant: bool = False
|
||||
needs_deep_gemm: bool = False
|
||||
needs_aiter: bool = False
|
||||
@@ -129,7 +128,6 @@ def register_experts(
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
supports_expert_map: bool,
|
||||
needs_matching_quant: bool = False,
|
||||
needs_deep_gemm: bool = False,
|
||||
needs_aiter: bool = False,
|
||||
@@ -142,7 +140,6 @@ def register_experts(
|
||||
activation_format,
|
||||
supported_dtypes,
|
||||
blocked_quantization_support,
|
||||
supports_expert_map,
|
||||
needs_matching_quant,
|
||||
needs_deep_gemm,
|
||||
needs_aiter,
|
||||
@@ -176,7 +173,6 @@ register_experts(
|
||||
batched_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
|
||||
@@ -185,7 +181,6 @@ register_experts(
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
|
||||
@@ -194,7 +189,6 @@ register_experts(
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=True,
|
||||
)
|
||||
|
||||
# Disable on blackwell for now
|
||||
@@ -260,7 +254,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
supports_expert_map=True,
|
||||
)
|
||||
else:
|
||||
FlashInferCutlassMoEPrepareAndFinalize = None
|
||||
@@ -294,7 +287,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
|
||||
standard_format,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_expert_map=True,
|
||||
)
|
||||
|
||||
if has_aiter():
|
||||
@@ -307,7 +299,6 @@ if has_aiter():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=True,
|
||||
needs_aiter=True,
|
||||
)
|
||||
else:
|
||||
@@ -319,7 +310,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
@@ -328,7 +318,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
@@ -337,7 +326,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
@@ -353,14 +341,12 @@ if cutlass_fp8_supported():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
register_experts(
|
||||
CutlassBatchedExpertsFp8,
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
else:
|
||||
CutlassBatchedExpertsFp8 = None
|
||||
@@ -376,7 +362,6 @@ if cutlass_fp4_supported():
|
||||
standard_format,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
else:
|
||||
CutlassExpertsFp4 = None
|
||||
|
||||
@@ -227,7 +227,7 @@ def is_nyi_config(config: Config) -> bool:
|
||||
) == 1
|
||||
return unsupported_quant_config
|
||||
|
||||
return not info.supports_expert_map
|
||||
return False
|
||||
|
||||
|
||||
def generate_valid_test_cases(
|
||||
|
||||
@@ -248,9 +248,6 @@ class AiterW4A8ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False # Expert parallelism not yet supported
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -316,9 +316,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_packed_ue8m0_act_scales(self) -> bool:
|
||||
"""
|
||||
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
|
||||
|
||||
@@ -100,9 +100,6 @@ class CPUExpertsFp8(mk.FusedMoEExpertsMonolithic):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -256,9 +253,6 @@ class CPUExpertsMxfp4(mk.FusedMoEExpertsMonolithic):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -378,7 +378,8 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
# the fp8 cutlass experts use their own expert map.
|
||||
None,
|
||||
self.w1_scale,
|
||||
self.w2_scale,
|
||||
a1q_scale,
|
||||
@@ -418,9 +419,6 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# topk weights and reduction are fused in moe_unpermute cuda kernel
|
||||
return TopKWeightAndReduceNoOP()
|
||||
@@ -460,9 +458,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return self.out_dtype if self.out_dtype is not None else act_dtype
|
||||
|
||||
@@ -741,9 +736,6 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
@@ -1038,9 +1030,6 @@ class CutlassExpertsMxfp4(mk.FusedMoEExpertsModular):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
@@ -1340,9 +1329,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# topk weights and reduction are fused in moe_unpermute cuda kernel
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
@@ -164,9 +164,6 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
@@ -388,9 +385,6 @@ class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -92,16 +92,6 @@ class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
|
||||
moe_parallel_config
|
||||
) and fallback_cls._supports_parallel_config(moe_parallel_config)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
assert (
|
||||
self.experts.supports_expert_map()
|
||||
== self.fallback_experts.supports_expert_map()
|
||||
)
|
||||
return (
|
||||
self.experts.supports_expert_map()
|
||||
and self.fallback_experts.supports_expert_map()
|
||||
)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
e_war = self.experts.finalize_weight_and_reduce_impl()
|
||||
fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl()
|
||||
|
||||
@@ -89,9 +89,6 @@ class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
@@ -98,9 +98,6 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -207,9 +207,6 @@ class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -555,9 +555,6 @@ class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
|
||||
"This method should not be called."
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
@@ -799,9 +796,6 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
@@ -156,9 +156,6 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
platform = current_platform
|
||||
|
||||
@@ -608,9 +608,6 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
@@ -1036,9 +1033,6 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -686,9 +686,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
|
||||
"""Marlin-based fused MoE expert implementation."""
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
@@ -920,9 +917,6 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
|
||||
@@ -441,9 +441,6 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
def supports_expert_map(self):
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -128,9 +128,6 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular):
|
||||
def _supports_batch_invariance():
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -99,9 +99,6 @@ class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -88,9 +88,6 @@ class TrtLlmFp8ExpertsBase:
|
||||
or moe_parallel_config.use_ag_rs_all2all_kernels
|
||||
) and not moe_parallel_config.enable_eplb
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
|
||||
@@ -113,9 +113,6 @@ class TrtLlmMxfp4ExpertsBase:
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return False
|
||||
@@ -248,9 +245,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
# routing is done externally, so accept any routing method.
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -179,9 +179,6 @@ class TrtLlmNvFp4ExpertsBase:
|
||||
300000, _calc_max_supported_tokens(self.topk, self.moe_config.num_experts)
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
|
||||
@@ -107,9 +107,6 @@ class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -34,11 +34,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
super().__init__(moe_kernel.moe_config)
|
||||
self.moe_quant_config = old_quant_method.moe_quant_config
|
||||
self.moe_kernel = moe_kernel
|
||||
self.disable_expert_map = getattr(
|
||||
old_quant_method,
|
||||
"disable_expert_map",
|
||||
not self.moe_kernel.supports_expert_map(),
|
||||
)
|
||||
self.old_quant_method = old_quant_method
|
||||
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
|
||||
|
||||
@@ -103,7 +98,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts=shared_experts,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
@@ -751,13 +751,6 @@ class FusedMoEExperts(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_packed_ue8m0_act_scales(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class can process packed ue8m0
|
||||
@@ -1567,12 +1560,6 @@ class FusedMoEKernel:
|
||||
== self.fused_experts.activation_format()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
|
||||
@@ -17,7 +17,7 @@ def _quantize_and_setup_dispatch(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor] | None, torch.Tensor | None]:
|
||||
# Defer input quantization to the MoE kernel.
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
@@ -33,7 +33,7 @@ def _quantize_and_setup_dispatch(
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
@@ -49,7 +49,7 @@ def _quantize_and_setup_dispatch(
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
return a1q, scales
|
||||
return a1q, scales, a1q_scale
|
||||
|
||||
|
||||
def _unwrap_scale_and_prepare_for_moe(
|
||||
@@ -129,7 +129,9 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
|
||||
)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
a1q, scales, a1q_scale_orig = _quantize_and_setup_dispatch(
|
||||
a1, quant_config, defer_input_quant
|
||||
)
|
||||
|
||||
# When LoRA is active, dispatch the per-token LoRA id along with
|
||||
# hidden_states so every rank receives the correct mapping for the
|
||||
@@ -164,7 +166,7 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
|
||||
if extra_tensors is None:
|
||||
assert len(res) == 3
|
||||
a1q, topk_weights, topk_ids = res
|
||||
a1q_scale = None
|
||||
a1q_scale = a1q_scale_orig
|
||||
else:
|
||||
assert len(res) == 4
|
||||
a1q, topk_weights, topk_ids, gathered_extras = res
|
||||
@@ -178,7 +180,7 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
|
||||
gathered_extras, quant_config
|
||||
)
|
||||
else:
|
||||
a1q_scale = None
|
||||
a1q_scale = a1q_scale_orig
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
@@ -249,7 +251,9 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
"""Quantize and Dispatch Router Logits."""
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
a1q, scales, a1q_scale_orig = _quantize_and_setup_dispatch(
|
||||
a1, quant_config, defer_input_quant
|
||||
)
|
||||
|
||||
res = get_ep_group().dispatch_router_logits(
|
||||
a1q,
|
||||
@@ -261,7 +265,7 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
|
||||
if scales is None:
|
||||
assert len(res) == 2
|
||||
a1q, router_logits = res
|
||||
a1q_scale = None
|
||||
a1q_scale = a1q_scale_orig
|
||||
else:
|
||||
assert len(res) == 3
|
||||
a1q, router_logits, scales = res
|
||||
|
||||
-2
@@ -405,8 +405,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
# TODO(rob): investigate the disable_expert_map introduced by:
|
||||
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts=shared_experts,
|
||||
|
||||
Reference in New Issue
Block a user