Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[fix] Fix wide EP when using DeepEP with online EPLB (#6429)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
parent c9ed1ab436
commit a427f5bece
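The diff below adds the allgather of the load balancer's local statistics to the DeepEP and DeepEPLowLatency all-to-all branches of WideEPMoE; per the commit title, online EPLB previously did not receive the per-rank expert-load counts when DeepEP was the all-to-all backend.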
@@ -470,6 +470,10 @@ class WideEPMoE(MoE):
                     self.expert_size_per_partition * self.mapping.moe_ep_rank)
                 padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, None, recv_topk_idx, token_final_scales)
+                if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
+                ):
+                    gathered_loadbalancer_local_statistic_info = allgather(
+                        loadbalancer_local_statistic_info, self.mapping, dim=0)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots
@@ -499,6 +503,10 @@ class WideEPMoE(MoE):
                         x.shape[0], 1)
                     token_final_scales = torch.ones_like(
                         token_selected_slots, dtype=token_final_scales.dtype)
+                if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
+                ):
+                    gathered_loadbalancer_local_statistic_info = allgather(
+                        loadbalancer_local_statistic_info, self.mapping, dim=0)
 
         x_sf = None
         x_row = x.shape[0]
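The guard added in both branches runs the collective only once per iteration (on the last chunked call) and only when routing is dynamic. A minimal, self-contained sketch of that pattern using plain torch.distributed follows; the function name and the use of dist.all_gather here are illustrative assumptions, not the TensorRT-LLM allgather helper that the diff actually calls.

# Hedged sketch, not the TensorRT-LLM implementation: each EP rank holds a
# tensor of per-expert token counts; on the last chunked call of an iteration,
# and only for dynamic (online EPLB) routing, all ranks exchange their counts
# so every rank can compute the same rebalanced expert placement.
import torch
import torch.distributed as dist


def gather_eplb_statistics(local_stats: torch.Tensor, is_last_call: bool,
                           static_routing: bool):
    """Approximates allgather(loadbalancer_local_statistic_info, ..., dim=0)."""
    if not is_last_call or static_routing:
        return None  # mirrors the guard added in both DeepEP branches
    chunks = [
        torch.empty_like(local_stats) for _ in range(dist.get_world_size())
    ]
    dist.all_gather(chunks, local_stats)
    return torch.cat(chunks, dim=0)  # one block per EP rank, stacked on dim 0

Run under torchrun with an initialized process group, gather_eplb_statistics(counts, is_last_call=True, static_routing=False) returns a tensor whose dim-0 blocks are each rank's local counters, which is the shape the online load balancer needs to compute a global expert-load view.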