From 795e690bca24108d3a660dc36efc2f059853538b Mon Sep 17 00:00:00 2001
From: Leslie Fang
Date: Wed, 14 Jan 2026 10:42:17 +0800
Subject: [PATCH] [https://nvbugs/5753788][chore] Padding empty chunk for configurable moe (#10451)

Signed-off-by: leslie-fang25
---
 .../modules/fused_moe/configurable_moe.py | 23 ++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
index 07fad25df4..d0fbd64856 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -837,6 +837,26 @@ class ConfigurableMoE(MoE):
             all_rank_num_tokens_list, chunk_size_list, use_multi_stream
         )
 
+        # ========== Padding empty chunk ==========
+        chunked_used = torch.ones(num_chunks, dtype=torch.bool)
+        if self.use_dp:
+            # For empty chunk, will use chunk 0 instead. The current split heuristic
+            # ensures that if an empty chunk exists, Chunk 0 contains exactly one token.
+            assert x_list[0].numel() != 0, "chunk 0 shouldn't be empty"
+            x_list = list(x_list)
+            router_logits_list = list(router_logits_list)
+            for idx_chunk in range(num_chunks):
+                _x = x_list[idx_chunk]
+                if _x.numel() == 0:
+                    chunked_used[idx_chunk] = False
+                    x_list[idx_chunk] = x_list[0]
+                    router_logits_list[idx_chunk] = router_logits_list[0]
+                    all_rank_num_tokens_list[idx_chunk][self.mapping.tp_rank] = (
+                        all_rank_num_tokens_list[0][self.mapping.tp_rank]
+                    )
+            x_list = tuple(x_list)
+            router_logits_list = tuple(router_logits_list)
+
         # ========== Execute chunking with overlap ==========
         outputs_list = []
         for idx_chunk, (x_chunk, router_logits_chunk) in enumerate(zip(x_list, router_logits_list)):
@@ -888,7 +908,8 @@ class ConfigurableMoE(MoE):
                 workspace=workspace_0,
             )
 
-            outputs_list.append(outputs)
+            if chunked_used[idx_chunk]:
+                outputs_list.append(outputs)
 
         # ========== Wait for auxiliary stream to complete ==========
         if use_multi_stream:
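
Note for reviewers: below is a minimal standalone sketch of the padding technique the patch applies, written outside the TensorRT-LLM codebase. Empty chunks are substituted with chunk 0 so every chunk launch sees at least one token, and their outputs are dropped afterwards. The helper names (run_chunk, forward_chunks) and tensor shapes are illustrative assumptions, not the actual ConfigurableMoE API.

import torch


def run_chunk(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the per-chunk MoE forward; assume it cannot
    # safely be launched on a zero-token input.
    return x * 2.0


def forward_chunks(x_list):
    num_chunks = len(x_list)
    chunk_used = torch.ones(num_chunks, dtype=torch.bool)

    # Padding step: replace any empty chunk with chunk 0, which the split
    # heuristic is assumed to keep non-empty, and remember it was padded.
    assert x_list[0].numel() != 0, "chunk 0 shouldn't be empty"
    x_list = list(x_list)
    for i, x in enumerate(x_list):
        if x.numel() == 0:
            chunk_used[i] = False
            x_list[i] = x_list[0]

    # Execute every chunk, but only keep outputs of chunks that really had tokens.
    outputs = []
    for i, x in enumerate(x_list):
        out = run_chunk(x)
        if chunk_used[i]:
            outputs.append(out)
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    chunks = (torch.randn(1, 4), torch.empty(0, 4), torch.randn(3, 4))
    print(forward_chunks(chunks).shape)  # torch.Size([4, 4])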