[None][fix] Fix enable_alltoall passed to CutlassFusedMoE (#11016)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2026-01-29 12:11:07 +08:00 committed by GitHub
parent 24ac86c485
commit 34a730aaf7
3 changed files with 16 additions and 30 deletions

@@ -206,7 +206,11 @@ if args.run_type == "GEN":
max(1, 20480 // ctx_seq_len_q),
)
ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
with mock.patch.dict(os.environ, {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled"}, clear=False):
with mock.patch.dict(
os.environ,
{"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled", "TRTLLM_FORCE_COMM_METHOD": "ALLGATHER"},
clear=False,
):
ctx_runner = Runner(
args.model,
mapping,
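The GEN-phase setup above now pins both environment variables before constructing the context Runner, so MoE layers created inside the block see a fixed communication method. A minimal, self-contained sketch of the same pattern, assuming only the env var names taken from the hunk (the Runner construction itself is elided):

import os
from unittest import mock

# Env overrides from the hunk above; they apply only inside the context manager.
forced_comm_env = {
    "TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled",
    "TRTLLM_FORCE_COMM_METHOD": "ALLGATHER",
}

with mock.patch.dict(os.environ, forced_comm_env, clear=False):
    # Build the Runner / run the GEN phase here; anything that reads these
    # variables during construction observes the forced values.
    assert os.environ["TRTLLM_FORCE_COMM_METHOD"] == "ALLGATHER"
# On exit, mock.patch.dict restores the previous environment.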

@@ -313,7 +313,7 @@ class ConfigurableMoE(MoE):
2. Validates if current AllToAll strategy can be used for given workload
3. Falls back to AllGather if current strategy cannot be used (logs info message)
After calling this method, use _is_using_alltoall() to check which method is active.
After calling this method, use enable_alltoall to check which method is active.
Args:
all_rank_num_tokens: Token counts per rank
@@ -348,23 +348,6 @@ class ConfigurableMoE(MoE):
# Switch to AllGather (always works)
self.comm = AllGatherReduceScatter(mapping=self.mapping)
def _is_using_alltoall(self) -> bool:
"""
Check if current communication strategy uses alltoall
Returns:
True: Strategy uses alltoall (NVLINK, DeepEP, etc.)
False: Strategy uses allgather (AllGatherReduceScatter or None)
Note: Can be called anytime. If comm is None, returns False (no alltoall).
Typically called after determine_communication_method() to get accurate result.
"""
if self.comm is None:
return False # No strategy means no alltoall
# AllGather uses allgather, all others use alltoall
return not isinstance(self.comm, AllGatherReduceScatter)
def _create_comm_strategy_auto(self) -> Communication:
"""
Auto-create the best communication strategy based on hardware and configuration
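For reference, a self-contained sketch of the check removed by the hunk above: the old helper inferred "alltoall vs. allgather" from the type of the active strategy object, while the updated code reads the enable_alltoall flag kept in sync by determine_communication_method(). The classes below are simplified stand-ins, not the real strategy implementations:

# Simplified stand-ins for the strategy classes named in the diff.
class AllGatherReduceScatter:
    pass

class NVLinkAllToAll:  # hypothetical alltoall-style strategy, for illustration only
    pass

def is_using_alltoall(comm) -> bool:
    # Old-style check: no strategy, or AllGatherReduceScatter, means no alltoall;
    # every other strategy type is treated as alltoall.
    if comm is None:
        return False
    return not isinstance(comm, AllGatherReduceScatter)

assert is_using_alltoall(None) is False
assert is_using_alltoall(AllGatherReduceScatter()) is False
assert is_using_alltoall(NVLinkAllToAll()) is True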
@@ -810,11 +793,7 @@ class ConfigurableMoE(MoE):
Same as original implementation - chunking logic is backend-agnostic
Note: use_all_to_all is determined internally via _is_using_alltoall()
"""
# Determine if using alltoall
use_all_to_all = self._is_using_alltoall()
# ========== Chunk preparation ==========
if self.use_dp:
# When using DP: need all ranks' token counts for reducescatter
@@ -828,7 +807,7 @@ class ConfigurableMoE(MoE):
chunk_size_list = all_rank_chunk_size_list[self.rank]
# For alltoall, replace 0 with 1 (avoid empty tensor)
if use_all_to_all:
if self.enable_alltoall:
all_rank_num_tokens_list = [
[1 if val == 0 else val for val in val_list]
for val_list in all_rank_num_tokens_list
@@ -842,7 +821,7 @@ class ConfigurableMoE(MoE):
router_logits_list = router_logits.split(chunk_size_list)
# Determine if we need multiple streams for overlapped execution
use_multi_stream = not use_all_to_all and self.aux_stream is not None
use_multi_stream = not self.enable_alltoall and self.aux_stream is not None
# ========== Setup auxiliary stream ==========
if use_multi_stream:
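The hunks above switch the chunked forward path from a locally computed use_all_to_all to the layer's enable_alltoall attribute. A small sketch of the token-count handling they touch, assuming the convention from the diff that zero-token ranks are padded to one token on the alltoall path so later splits never produce empty tensors:

from typing import List

def pad_zero_token_ranks(all_rank_num_tokens_list: List[List[int]],
                         enable_alltoall: bool) -> List[List[int]]:
    # On the allgather path the counts are left untouched.
    if not enable_alltoall:
        return all_rank_num_tokens_list
    # On the alltoall path, replace 0 with 1 per rank and chunk.
    return [[1 if val == 0 else val for val in val_list]
            for val_list in all_rank_num_tokens_list]

chunks = [[4, 3], [2, 0]]  # per-chunk, per-rank token counts; rank 1 is empty in chunk 2
print(pad_zero_token_ranks(chunks, enable_alltoall=True))   # [[4, 3], [2, 1]]
print(pad_zero_token_ranks(chunks, enable_alltoall=False))  # [[4, 3], [2, 0]]

The multi-stream decision follows the same flag: overlapped execution is only enabled on the allgather path, and only when an auxiliary stream is available.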
@@ -1086,7 +1065,8 @@ class ConfigurableMoE(MoE):
# Only the non-alltoall case is considered for profiling in the warmup phase.
# Therefore, to get the correct tactics during the actual inference, the inputs to the tuner
# should be the same as when not using alltoall.
if self._is_using_alltoall():
kwargs["enable_alltoall"] = self.enable_alltoall
if self.enable_alltoall:
if all_rank_num_tokens is not None:
kwargs["tuner_num_tokens"] = sum(all_rank_num_tokens)
else:
@@ -1094,9 +1074,6 @@ class ConfigurableMoE(MoE):
x.shape[0] * self.mapping.tp_size if x is not None else None
)
kwargs["tuner_top_k"] = self.routing_method.top_k
else:
kwargs["tuner_num_tokens"] = None
kwargs["tuner_top_k"] = None
# Get moe_output for NVLinkOneSided backend
kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(

@@ -411,6 +411,7 @@ class CutlassFusedMoE(MoE):
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,
moe_output: Optional[torch.Tensor] = None,
enable_alltoall: Optional[bool] = None,
) -> torch.Tensor:
"""
Run MoE computation with Cutlass backend.
@@ -429,6 +430,7 @@ class CutlassFusedMoE(MoE):
tuner_num_tokens: Number of tokens for profiling tuner (optional)
tuner_top_k: Top-k value for profiling tuner (optional)
moe_output: Pre-allocated output buffer (optional)
enable_alltoall: Whether alltoall communication is enabled (optional). If None, defaults to self.enable_alltoall.
Returns:
final_hidden_states: Output tensor from MoE computation
@@ -441,6 +443,9 @@ class CutlassFusedMoE(MoE):
elif self.has_w4a16_mxfp4:
weight_dtype = torch.uint8
if enable_alltoall is None:
enable_alltoall = self.enable_alltoall
result = torch.ops.trtllm.fused_moe(
x,
token_selected_experts,
@@ -462,7 +467,7 @@ class CutlassFusedMoE(MoE):
ep_rank=self.ep_rank,
cluster_size=self.cluster_size,
cluster_rank=self.cluster_rank,
enable_alltoall=self.enable_alltoall,
enable_alltoall=enable_alltoall,
use_deepseek_fp8_block_scale=self.has_deepseek_fp8_block_scales,
use_w4_group_scaling=self.has_w4afp8 or self.has_w4a16_mxfp4,
use_int8_woq_per_channel=self.has_int8_woq_per_channel,
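The net effect in this file is that a resolved flag passed by the caller takes precedence over the backend's own attribute, with None meaning "fall back to self.enable_alltoall". A minimal sketch of that fallback pattern, with simplified names standing in for the real class:

from typing import Optional

class CutlassBackendSketch:
    # Illustration only: an optional argument that defaults to the instance
    # attribute when the caller does not supply a resolved value.
    def __init__(self, enable_alltoall: bool):
        self.enable_alltoall = enable_alltoall

    def run_moe(self, enable_alltoall: Optional[bool] = None) -> bool:
        if enable_alltoall is None:      # mirrors the fallback in the hunk above
            enable_alltoall = self.enable_alltoall
        return enable_alltoall           # forwarded to the fused MoE op in the real code

backend = CutlassBackendSketch(enable_alltoall=False)
assert backend.run_moe() is False                      # falls back to the attribute
assert backend.run_moe(enable_alltoall=True) is True   # caller's resolved value wins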