From 34a730aaf7ddb1f053c0a7d14ade1ae4ef517291 Mon Sep 17 00:00:00 2001
From: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Date: Thu, 29 Jan 2026 12:11:07 +0800
Subject: [PATCH] [None][fix] Fix enable_alltoall passed to CutlassFusedMoE (#11016)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
---
 examples/layer_wise_benchmarks/run.py              |  6 +++-
 .../modules/fused_moe/configurable_moe.py          | 33 +++----------------
 .../modules/fused_moe/fused_moe_cutlass.py         |  7 +++-
 3 files changed, 16 insertions(+), 30 deletions(-)

diff --git a/examples/layer_wise_benchmarks/run.py b/examples/layer_wise_benchmarks/run.py
index 1c33881ca7..a560cff958 100644
--- a/examples/layer_wise_benchmarks/run.py
+++ b/examples/layer_wise_benchmarks/run.py
@@ -206,7 +206,11 @@ if args.run_type == "GEN":
         max(1, 20480 // ctx_seq_len_q),
     )
     ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
-    with mock.patch.dict(os.environ, {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled"}, clear=False):
+    with mock.patch.dict(
+        os.environ,
+        {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled", "TRTLLM_FORCE_COMM_METHOD": "ALLGATHER"},
+        clear=False,
+    ):
         ctx_runner = Runner(
             args.model,
             mapping,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
index ce251234e5..f1de1b752b 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -313,7 +313,7 @@ class ConfigurableMoE(MoE):
         2. Validates if current AllToAll strategy can be used for given workload
         3. Falls back to AllGather if current strategy cannot be used (logs info message)
 
-        After calling this method, use _is_using_alltoall() to check which method is active.
+        After calling this method, use enable_alltoall to check which method is active.
 
         Args:
             all_rank_num_tokens: Token counts per rank
@@ -348,23 +348,6 @@ class ConfigurableMoE(MoE):
             # Switch to AllGather (always works)
             self.comm = AllGatherReduceScatter(mapping=self.mapping)
 
-    def _is_using_alltoall(self) -> bool:
-        """
-        Check if current communication strategy uses alltoall
-
-        Returns:
-            True: Strategy uses alltoall (NVLINK, DeepEP, etc.)
-            False: Strategy uses allgather (AllGatherReduceScatter or None)
-
-        Note: Can be called anytime. If comm is None, returns False (no alltoall).
-            Typically called after determine_communication_method() to get accurate result.
-        """
-        if self.comm is None:
-            return False  # No strategy means no alltoall
-
-        # AllGather uses allgather, all others use alltoall
-        return not isinstance(self.comm, AllGatherReduceScatter)
-
     def _create_comm_strategy_auto(self) -> Communication:
         """
         Auto-create the best communication strategy based on hardware and configuration
@@ -810,11 +793,7 @@ class ConfigurableMoE(MoE):
 
         Same as original implementation - chunking logic is backend-agnostic
 
-        Note: use_all_to_all is determined internally via _is_using_alltoall()
         """
-        # Determine if using alltoall
-        use_all_to_all = self._is_using_alltoall()
-
         # ========== Chunk preparation ==========
         if self.use_dp:
             # When using DP: need all ranks' token counts for reducescatter
@@ -828,7 +807,7 @@ class ConfigurableMoE(MoE):
             chunk_size_list = all_rank_chunk_size_list[self.rank]
 
             # For alltoall, replace 0 with 1 (avoid empty tensor)
-            if use_all_to_all:
+            if self.enable_alltoall:
                 all_rank_num_tokens_list = [
                     [1 if val == 0 else val for val in val_list]
                     for val_list in all_rank_num_tokens_list
@@ -842,7 +821,7 @@ class ConfigurableMoE(MoE):
             router_logits_list = router_logits.split(chunk_size_list)
 
         # Determine if we need multiple streams for overlapped execution
-        use_multi_stream = not use_all_to_all and self.aux_stream is not None
+        use_multi_stream = not self.enable_alltoall and self.aux_stream is not None
 
         # ========== Setup auxiliary stream ==========
         if use_multi_stream:
@@ -1086,7 +1065,8 @@ class ConfigurableMoE(MoE):
         # Only the non-alltoall case is considered for profiling in the warmup phase.
         # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner
         # should be the same as when not using alltoall.
-        if self._is_using_alltoall():
+        kwargs["enable_alltoall"] = self.enable_alltoall
+        if self.enable_alltoall:
             if all_rank_num_tokens is not None:
                 kwargs["tuner_num_tokens"] = sum(all_rank_num_tokens)
             else:
@@ -1094,9 +1074,6 @@ class ConfigurableMoE(MoE):
                     x.shape[0] * self.mapping.tp_size if x is not None else None
                 )
             kwargs["tuner_top_k"] = self.routing_method.top_k
-        else:
-            kwargs["tuner_num_tokens"] = None
-            kwargs["tuner_top_k"] = None
 
         # Get moe_output for NVLinkOneSided backend
         kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index e392bfadee..446cbec3a0 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -411,6 +411,7 @@ class CutlassFusedMoE(MoE):
         tuner_num_tokens: Optional[int] = None,
         tuner_top_k: Optional[int] = None,
         moe_output: Optional[torch.Tensor] = None,
+        enable_alltoall: Optional[bool] = None,
     ) -> torch.Tensor:
         """
         Run MoE computation with Cutlass backend.
@@ -429,6 +430,7 @@ class CutlassFusedMoE(MoE):
             tuner_num_tokens: Number of tokens for profiling tuner (optional)
             tuner_top_k: Top-k value for profiling tuner (optional)
             moe_output: Pre-allocated output buffer (optional)
+            enable_alltoall: Whether alltoall communication is enabled (optional). If None, defaults to self.enable_alltoall.
 
         Returns:
             final_hidden_states: Output tensor from MoE computation
@@ -441,6 +443,9 @@ class CutlassFusedMoE(MoE):
         elif self.has_w4a16_mxfp4:
             weight_dtype = torch.uint8
 
+        if enable_alltoall is None:
+            enable_alltoall = self.enable_alltoall
+
         result = torch.ops.trtllm.fused_moe(
             x,
             token_selected_experts,
@@ -462,7 +467,7 @@ class CutlassFusedMoE(MoE):
             ep_rank=self.ep_rank,
             cluster_size=self.cluster_size,
             cluster_rank=self.cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=enable_alltoall,
             use_deepseek_fp8_block_scale=self.has_deepseek_fp8_block_scales,
             use_w4_group_scaling=self.has_w4afp8 or self.has_w4a16_mxfp4,
             use_int8_woq_per_channel=self.has_int8_woq_per_channel,
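
Note on the shape of the fix: ConfigurableMoE now forwards its resolved enable_alltoall flag to the backend through kwargs (and drops the _is_using_alltoall() helper), while CutlassFusedMoE.run_moe falls back to its own self.enable_alltoall only when no flag is passed. A minimal, self-contained sketch of that argument-with-fallback pattern, using simplified stand-in classes rather than the real TensorRT-LLM types:

from typing import Optional


class Backend:
    """Simplified stand-in for CutlassFusedMoE (illustration only)."""

    def __init__(self, enable_alltoall: bool):
        self.enable_alltoall = enable_alltoall

    def run_moe(self, enable_alltoall: Optional[bool] = None) -> bool:
        # Same fallback as the patch: an explicit argument wins,
        # otherwise fall back to the backend's own attribute.
        if enable_alltoall is None:
            enable_alltoall = self.enable_alltoall
        return enable_alltoall


class Wrapper:
    """Simplified stand-in for ConfigurableMoE (illustration only)."""

    def __init__(self, backend: Backend, enable_alltoall: bool):
        self.backend = backend
        # Decided by the chosen communication strategy (alltoall vs. allgather).
        self.enable_alltoall = enable_alltoall

    def forward(self) -> bool:
        kwargs = {"enable_alltoall": self.enable_alltoall}
        return self.backend.run_moe(**kwargs)


if __name__ == "__main__":
    # The wrapper's decision (alltoall disabled) overrides the backend default,
    assert Wrapper(Backend(enable_alltoall=True), enable_alltoall=False).forward() is False
    # while a direct call without the flag still uses the backend's own setting.
    assert Backend(enable_alltoall=True).run_moe() is True

Passing the flag explicitly keeps the wrapper's communication decision authoritative while leaving the backend usable on its own.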