[fix] Fix can_use_alltoall in fused_moe_wide_ep.py (#6173)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Jinyang Yuan 2025-07-21 10:53:07 +08:00 committed by GitHub
parent b4c7e8c9a5
commit 88076eecd0

@@ -283,16 +283,14 @@ class WideEPMoE(MoE):
         return (num_rows + self.moe_max_num_tokens -
                 1) // self.moe_max_num_tokens

-    def can_use_alltoall(self, input, all_rank_num_tokens):
+    def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
             return False

-        num_tokens = input.shape[0]
-
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
-                and num_tokens > self.deep_ep_max_num_tokens):
+                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
             return False

         return self.enable_alltoall
@@ -726,7 +724,8 @@ class WideEPMoE(MoE):

         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
         num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
-        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens,
+                                               all_rank_max_num_tokens)

         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
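
For context, a minimal standalone sketch of the post-fix decision, assuming the attributes touched by this diff (alltoall_method_type, enable_alltoall, deep_ep_max_num_tokens); the free function, enum subset, and example values below are hypothetical and only illustrate why deciding from all_rank_max_num_tokens instead of the local input.shape[0] makes every rank evaluate the same condition and reach the same alltoall decision:

    from enum import Enum, auto


    class AlltoallMethodType(Enum):
        # Hypothetical subset of the method types referenced in the diff.
        NotEnabled = auto()
        DeepEPLowLatency = auto()


    def can_use_alltoall(alltoall_method_type: AlltoallMethodType,
                         enable_alltoall: bool,
                         deep_ep_max_num_tokens: int,
                         num_chunks: int,
                         all_rank_max_num_tokens: int) -> bool:
        """Standalone sketch of the updated decision logic."""
        # Disable alltoall when chunking is used.
        if num_chunks > 1:
            return False
        # For DeepEPLowLatency, the token count must stay within
        # deep_ep_max_num_tokens; checking the max over all ranks keeps
        # the decision identical on every rank.
        if (alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                and all_rank_max_num_tokens > deep_ep_max_num_tokens):
            return False
        return enable_alltoall


    # Example: two ranks with 96 and 128 local tokens and a 112-token limit.
    # Deciding from local counts, the first rank (96 <= 112) would pick
    # alltoall while the second (128 > 112) would not; deciding from
    # max(96, 128) = 128, both ranks agree on False.
    print(can_use_alltoall(AlltoallMethodType.DeepEPLowLatency, True, 112, 1, 128))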