[https://nvbugs/5753788][chore] Padding empty chunk for configurable moe (#10451)

Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Author: Leslie Fang
Date: 2026-01-14 10:42:17 +08:00 (committed by GitHub)
parent d3f4fbb742
commit 795e690bca

@@ -837,6 +837,26 @@ class ConfigurableMoE(MoE):
             all_rank_num_tokens_list, chunk_size_list, use_multi_stream
         )
+
+        # ========== Padding empty chunk ==========
+        chunked_used = torch.ones(num_chunks, dtype=torch.bool)
+        if self.use_dp:
+            # For an empty chunk, chunk 0 is used instead. The current split heuristic
+            # ensures that if an empty chunk exists, chunk 0 contains exactly one token.
+            assert x_list[0].numel() != 0, "chunk 0 shouldn't be empty"
+            x_list = list(x_list)
+            router_logits_list = list(router_logits_list)
+            for idx_chunk in range(num_chunks):
+                _x = x_list[idx_chunk]
+                if _x.numel() == 0:
+                    chunked_used[idx_chunk] = False
+                    x_list[idx_chunk] = x_list[0]
+                    router_logits_list[idx_chunk] = router_logits_list[0]
+                    all_rank_num_tokens_list[idx_chunk][self.mapping.tp_rank] = (
+                        all_rank_num_tokens_list[0][self.mapping.tp_rank]
+                    )
+            x_list = tuple(x_list)
+            router_logits_list = tuple(router_logits_list)
         # ========== Execute chunking with overlap ==========
         outputs_list = []
         for idx_chunk, (x_chunk, router_logits_chunk) in enumerate(zip(x_list, router_logits_list)):
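The padding step above is self-contained enough to sketch in isolation. Below is a minimal, hypothetical rewrite of it (the helper name pad_empty_chunks and its signature are illustrative, not the library's API): empty chunks are swapped for chunk 0, presumably so that every chunk iteration still launches the same kernel/collective sequence under attention DP, and a boolean mask records which chunk outputs are real.

    import torch

    def pad_empty_chunks(x_list, router_logits_list):
        """Replace empty chunks with chunk 0; return a mask of which chunks are real."""
        num_chunks = len(x_list)
        chunk_used = torch.ones(num_chunks, dtype=torch.bool)
        assert x_list[0].numel() != 0, "chunk 0 shouldn't be empty"
        x_list, router_logits_list = list(x_list), list(router_logits_list)
        for idx in range(num_chunks):
            if x_list[idx].numel() == 0:
                chunk_used[idx] = False          # this chunk's output will be a dummy
                x_list[idx] = x_list[0]          # reuse chunk 0's (single-token) input
                router_logits_list[idx] = router_logits_list[0]
        return tuple(x_list), tuple(router_logits_list), chunk_used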
@@ -888,7 +908,8 @@ class ConfigurableMoE(MoE):
                 workspace=workspace_0,
             )
-            outputs_list.append(outputs)
+            if chunked_used[idx_chunk]:
+                outputs_list.append(outputs)
         # ========== Wait for auxiliary stream to complete ==========
         if use_multi_stream:
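
For illustration only, a toy driver loop showing how the mask from the sketch above would be consumed (fake_moe is a stand-in for the real per-chunk forward, not part of the library): padded chunks still execute, but their dummy outputs never reach outputs_list.

    import torch

    def fake_moe(x, logits):
        # Stand-in for the real per-chunk MoE forward; just echoes the input.
        return x

    x_list = (torch.randn(1, 8), torch.empty(0, 8), torch.randn(3, 8))
    logits_list = (torch.randn(1, 4), torch.empty(0, 4), torch.randn(3, 4))
    x_list, logits_list, chunk_used = pad_empty_chunks(x_list, logits_list)

    outputs_list = []
    for idx, (x_chunk, logits_chunk) in enumerate(zip(x_list, logits_list)):
        out = fake_moe(x_chunk, logits_chunk)  # padded chunks still run
        if chunk_used[idx]:                    # but their dummy outputs are dropped
            outputs_list.append(out)

    assert len(outputs_list) == 2  # only the two genuinely non-empty chunks remain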