Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[fix] Fix can_use_alltoall in fused_moe_wide_ep.py (#6173)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
parent b4c7e8c9a5
commit 88076eecd0
@@ -283,16 +283,14 @@ class WideEPMoE(MoE):
         return (num_rows + self.moe_max_num_tokens -
                 1) // self.moe_max_num_tokens
 
-    def can_use_alltoall(self, input, all_rank_num_tokens):
+    def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
             return False
 
-        num_tokens = input.shape[0]
-
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
-                and num_tokens > self.deep_ep_max_num_tokens):
+                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
             return False
 
         return self.enable_alltoall
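The substance of the fix is that the DeepEPLowLatency threshold is now compared against the maximum token count across all ranks rather than the local rank's `input.shape[0]`, which among other things makes the decision identical on every rank. A toy comparison of the two checks (standalone sketch with made-up numbers; the real threshold and per-rank counts come from the model config and the attention-DP metadata):

```python
# Standalone sketch contrasting the old per-rank check with the new
# rank-agnostic one; numbers are illustrative, not real defaults.
deep_ep_max_num_tokens = 256
all_rank_num_tokens = [300, 100, 50]           # hypothetical per-rank token counts
all_rank_max_num_tokens = max(all_rank_num_tokens)

# Old check: each rank compared only its local input.shape[0], so ranks
# could disagree about a decision that has to be collective.
old_per_rank = [n <= deep_ep_max_num_tokens for n in all_rank_num_tokens]
print(old_per_rank)                            # [False, True, True]

# New check: every rank compares the same global maximum.
new_per_rank = [all_rank_max_num_tokens <= deep_ep_max_num_tokens
                for _ in all_rank_num_tokens]
print(new_per_rank)                            # [False, False, False]
```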
@@ -726,7 +724,8 @@ class WideEPMoE(MoE):
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
         num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
-        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens,
+                                               all_rank_max_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
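At the call site the only extra requirement is that `all_rank_max_num_tokens` be available next to `all_rank_num_tokens`. A small sketch of the assumed relationship between the two values (made-up counts; in the real forward path both come from the attention-DP metadata, and the padded branch visible at the end of the hunk replicates the same maximum):

```python
# Hypothetical per-rank token counts standing in for attention-DP metadata.
all_rank_num_tokens = [300, 100, 50]
all_rank_max_num_tokens = max(all_rank_num_tokens)   # assumed: the per-rank maximum

# Mirrors the `all_rank_num_tokens_padded = [all_rank_max_num_tokens ...`
# line in the hunk above: with padding, every rank is treated as carrying
# the maximum number of tokens.
all_rank_num_tokens_padded = [all_rank_max_num_tokens] * len(all_rank_num_tokens)
print(all_rank_num_tokens_padded)                    # [300, 300, 300]
```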