[Bugfix] Replace code that disabled shared expert overlap (#39222)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
bnellnm
2026-04-20 19:36:16 -04:00
committed by GitHub
parent c075702eae
commit 6867bcd076
3 changed files with 17 additions and 4 deletions
@@ -990,6 +990,7 @@ class FusedMoEParallelConfig:
@property
def use_batched_activation_format(self):
# TODO(bnell): nixl also uses batched format
return self.use_deepep_ll_kernels
@property
@@ -427,9 +427,6 @@ class MoERunnerBase(MoERunner):
via the router, and the actual fused MoE computation. Returns
(shared_expert_output, fused_expert_output).
"""
# Run this before quant_method to avoid inplace issues.
# TODO(bnell): probably not needed anymore since inplace is
# disabled when shared experts are present.
self._maybe_apply_shared_experts(
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
)
@@ -447,7 +444,7 @@ class MoERunnerBase(MoERunner):
)
# Passing shared_experts_input in case SharedExpertsOrder is
# NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
# MK_INTERNAL_OVERLAPPED.
fused_out = self.quant_method.apply(
layer=layer,
x=hidden_states,
@@ -490,6 +487,7 @@ class MoERunnerBase(MoERunner):
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self._shared_experts is not None:
assert shared_experts_input is not None
self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
def _maybe_add_zero_expert_output(
@@ -80,10 +80,24 @@ class SharedExperts:
"Enabled separate cuda stream for MoE shared_experts", scope="local"
)
@property
def _disable_shared_experts_overlap(self) -> bool:
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
parallel_config = self._moe_config.moe_parallel_config
return (
parallel_config.enable_eplb
and parallel_config.all2all_backend != "allgather_reducescatter"
) or parallel_config.use_fi_nvl_two_sided_kernels
def _determine_shared_experts_order(
self,
hidden_states: torch.Tensor,
) -> SharedExpertsOrder:
if self._disable_shared_experts_overlap:
return SharedExpertsOrder.NO_OVERLAP
if self._quant_method.mk_owns_shared_expert:
return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED