[fix] Fix wide EP when using DeepEP with online EPLB (#6429)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Author: Jinyang Yuan
Date: 2025-07-30 12:13:18 +08:00 (committed by GitHub)
Commit: a427f5bece
Parent: c9ed1ab436

@@ -470,6 +470,10 @@ class WideEPMoE(MoE):
                     self.expert_size_per_partition * self.mapping.moe_ep_rank)
             padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                 x, None, recv_topk_idx, token_final_scales)
+            if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
+            ):
+                gathered_loadbalancer_local_statistic_info = allgather(
+                    loadbalancer_local_statistic_info, self.mapping, dim=0)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
             if not use_postquant_alltoall:
                 deep_ep_topk_idx = token_selected_slots
@@ -499,6 +503,10 @@ class WideEPMoE(MoE):
                     x.shape[0], 1)
                 token_final_scales = torch.ones_like(
                     token_selected_slots, dtype=token_final_scales.dtype)
+            if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
+            ):
+                gathered_loadbalancer_local_statistic_info = allgather(
+                    loadbalancer_local_statistic_info, self.mapping, dim=0)
         x_sf = None
         x_row = x.shape[0]
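
Both hunks add the same guard in the DeepEP and DeepEPLowLatency all-to-all paths: on the last chunked call, and only when the layer has an online (non-static) EPLB load balancer, the per-rank routing statistics are allgathered across EP ranks so the balancer can update expert placement. The sketch below illustrates that gating pattern in isolation; allgather, OnlineLoadBalancer, and maybe_gather_statistics here are stand-ins invented for the example, not TensorRT-LLM APIs, whereas the real code calls the project's allgather over self.mapping inside WideEPMoE.

# Minimal sketch of the pattern the diff adds (stand-in names, not TensorRT-LLM code).
import torch

def allgather(local: torch.Tensor, world_size: int) -> torch.Tensor:
    # Stand-in for the distributed allgather: stacks per-rank statistics along dim 0.
    # In the real code this is a collective over the EP group described by self.mapping.
    return torch.cat([local.clone() for _ in range(world_size)], dim=0)

class OnlineLoadBalancer:
    def is_static_routing(self) -> bool:
        # Online EPLB: routing tables are refreshed from gathered statistics.
        return False

def maybe_gather_statistics(local_stats, balancer, is_last_call, world_size):
    # Gather only on the last chunk of the iteration and only for online routing;
    # static routing never consumes the statistics, so nothing is gathered.
    if is_last_call and balancer is not None and not balancer.is_static_routing():
        return allgather(local_stats, world_size)
    return None

if __name__ == "__main__":
    local_stats = torch.tensor([[3, 1, 0, 2]])  # tokens routed to each local expert slot
    gathered = maybe_gather_statistics(local_stats, OnlineLoadBalancer(),
                                       is_last_call=True, world_size=4)
    print(gathered.shape)  # torch.Size([4, 4]): one row of statistics per EP rank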