From 88076eecd05184f122f134680bea3e1794c1a115 Mon Sep 17 00:00:00 2001
From: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Date: Mon, 21 Jul 2025 10:53:07 +0800
Subject: [PATCH] [fix] Fix can_use_alltoall in fused_moe_wide_ep.py (#6173)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
---
 .../_torch/modules/fused_moe/fused_moe_wide_ep.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 36de5ddc1b..81778c2854 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -283,16 +283,14 @@ class WideEPMoE(MoE):
         return (num_rows + self.moe_max_num_tokens -
                 1) // self.moe_max_num_tokens
 
-    def can_use_alltoall(self, input, all_rank_num_tokens):
+    def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
             return False
 
-        num_tokens = input.shape[0]
-
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
-                and num_tokens > self.deep_ep_max_num_tokens):
+                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
             return False
 
         return self.enable_alltoall
@@ -726,7 +724,8 @@ class WideEPMoE(MoE):
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
         num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
-        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens,
+                                               all_rank_max_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
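
A minimal, self-contained sketch of the decision logic after this fix, for illustration only; it is not part of the patch. The flat function signature, the AlltoallMethodType stub, and the example numbers are hypothetical stand-ins for the WideEPMoE attributes in fused_moe_wide_ep.py. The apparent rationale of the change is that the DeepEP low-latency capacity check should use the maximum token count across all ranks (all_rank_max_num_tokens) rather than the local rank's input size, so every rank reaches the same use_all_to_all decision.

from enum import Enum
from typing import List

class AlltoallMethodType(Enum):
    # Hypothetical subset of the enum referenced by the patch.
    NotEnabled = 0
    DeepEPLowLatency = 1

def can_use_alltoall(all_rank_num_tokens: List[int],
                     all_rank_max_num_tokens: int,
                     alltoall_method_type: AlltoallMethodType,
                     deep_ep_max_num_tokens: int,
                     moe_max_num_tokens: int,
                     enable_alltoall: bool) -> bool:
    """Standalone mirror of the patched can_use_alltoall logic (sketch)."""
    # Disable alltoall when chunking is used (more than one chunk needed).
    num_rows = sum(all_rank_num_tokens)
    num_chunks = (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens
    if num_chunks > 1:
        return False

    # For DeepEPLowLatency, compare the maximum token count across ranks
    # against the threshold instead of the local rank's token count.
    if (alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
            and all_rank_max_num_tokens > deep_ep_max_num_tokens):
        return False

    return enable_alltoall

# Example: four ranks with uneven token counts; the max (300) drives the check.
print(can_use_alltoall([100, 300, 250, 80], 300,
                       AlltoallMethodType.DeepEPLowLatency,
                       deep_ep_max_num_tokens=256,
                       moe_max_num_tokens=8192,
                       enable_alltoall=True))  # prints False, since 300 > 256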