From 7147efb2e80b83b012264693c5bca7da164cda0e Mon Sep 17 00:00:00 2001
From: dongxuy04 <78518666+dongxuy04@users.noreply.github.com>
Date: Fri, 9 May 2025 09:01:35 +0800
Subject: [PATCH] fix: alltoall padding for chunked MoE (#4157)

fix alltoall padding for chunked MoE

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_deepseekv3.py |  4 ++--
 tensorrt_llm/_torch/modules/fused_moe.py          | 15 +++++++--------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index e64f0990d4..39699e0059 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -457,14 +457,14 @@ class Deepseekv3MoE(nn.Module):
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, cutlass_min_latency_mode):
         # max-throughput
-        if self.use_dp and self.mapping.tp_size > 1 and not self.enable_alltoall:
+        if self.use_dp and self.mapping.tp_size > 1:
             max_num_token = max(all_rank_num_tokens)
             hidden_states = torch.nn.functional.pad(
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather():
+            if disable_fp4_allgather() and not self.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           gather_dim=0)
diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index 0a49df75a7..7007f471ac 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -903,14 +903,13 @@ class FusedMoE(nn.Module):
         if self.use_dp and self.enable_alltoall:
             all_rank_chunk_size_list = []
             for single_rank_num_tokens in all_rank_num_tokens:
-                single_rank_num_chunks = (single_rank_num_tokens +
-                                          max_chunk_size -
-                                          1) // max_chunk_size
-                assert single_rank_num_chunks == num_chunks,\
-                    "num_chunks should be the same for attention dp and ep"
-                all_rank_chunk_size_list.append(
-                    split_chunk(single_rank_num_tokens,
-                                single_rank_num_chunks))
+                single_rank_num_chunks = num_chunks
+                single_rank_chunk_size_list = split_chunk(
+                    single_rank_num_tokens, single_rank_num_chunks)
+                single_rank_chunk_size_list = [
+                    1 if x == 0 else x for x in single_rank_chunk_size_list
+                ]
+                all_rank_chunk_size_list.append(single_rank_chunk_size_list)
             for chunk_id in range(num_chunks):
                 chunk_all_rank_num_tokens = [
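
Note for reviewers (not part of the patch): the standalone sketch below illustrates the per-rank chunking behavior that the fused_moe.py hunk introduces. Every rank now uses the same num_chunks, and any zero-sized chunk produced by split_chunk is bumped to size 1 so that each alltoall step has a chunk to exchange on every rank. The split_chunk here is an assumed stand-in that splits a token count into nearly equal parts, and pad_rank_chunks is a hypothetical helper name; neither is TensorRT-LLM's actual implementation.

# Standalone sketch, not TensorRT-LLM code. `split_chunk` approximates the
# assumed behavior of the helper called in fused_moe.py; `pad_rank_chunks` is
# a hypothetical name for the loop added by this patch.


def split_chunk(total: int, num_chunks: int) -> list[int]:
    # Split `total` tokens into `num_chunks` parts whose sizes differ by at most 1.
    base, rem = divmod(total, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]


def pad_rank_chunks(all_rank_num_tokens: list[int],
                    num_chunks: int) -> list[list[int]]:
    # Force every rank to use the same number of chunks and replace any
    # zero-sized chunk with size 1 so no alltoall step sees an empty rank.
    all_rank_chunk_size_list = []
    for single_rank_num_tokens in all_rank_num_tokens:
        sizes = split_chunk(single_rank_num_tokens, num_chunks)
        sizes = [1 if x == 0 else x for x in sizes]
        all_rank_chunk_size_list.append(sizes)
    return all_rank_chunk_size_list


# Example: rank 1 has only 2 tokens but still contributes 3 (padded) chunks.
print(pad_rank_chunks([7, 2, 5], num_chunks=3))
# -> [[3, 2, 2], [1, 1, 1], [2, 2, 1]]

The token backing a padded size-1 chunk exists because the torch.nn.functional.pad call in modeling_deepseekv3.py now runs in the attention-DP path even when alltoall is enabled, instead of being skipped as before.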