diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index 97e1293cfb..07fad25df4 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch +from tensorrt_llm._torch.expert_statistic import ExpertStatistic from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.interface import MoE from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod @@ -619,6 +620,10 @@ class ConfigurableMoE(MoE): else: token_selected_slots = token_selected_experts + if token_selected_slots is not None: + ExpertStatistic.set_layer(self.layer_idx) + ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) + # ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ========== # NVLINK two-sided has a prepare phase to gather EPLB statistics @@ -647,6 +652,10 @@ class ConfigurableMoE(MoE): # supports_post_quant_dispatch checks strategy capability for the current quant mode supports_post_quant = self.comm.supports_post_quant_dispatch() + # Call dummy_allreduce before allgather for load balancing debug + if self.enable_dummy_allreduce: + self.dummy_allreduce() + if supports_post_quant: # ===== Post-quant flow: Quantize → Dispatch ===== @@ -710,6 +719,8 @@ class ConfigurableMoE(MoE): # ========== Step 9: Communication - Combine ========== if self.comm is not None: + if self.enable_dummy_allreduce: + self.dummy_allreduce() # Use unified combine interface (reads dispatch state from strategy) final_hidden_states = self.comm.combine(final_hidden_states) else: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index d8f32c3ff8..f0a605cc07 100755 --- 
a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -159,10 +159,6 @@ class WideEPMoE(MoE): if not model_config.skip_create_weights_in_init: self.create_weights() - # Debug function for eliminating imbalance during performance analysis. - self.enable_dummy_allreduce = os.environ.get( - "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1" - # MoE op will be lazily initialized when first accessed (see moe_op_impl property) self._moe_op_impl = None @@ -342,16 +338,6 @@ class WideEPMoE(MoE): self._moe_op_impl = MoEOpSelector.select_op(self) return self._moe_op_impl - def dummy_allreduce(self): - """ - Debug function for eliminating imbalance during performance analysis. - Creates a small dummy tensor and performs allreduce to synchronize processes - and eliminate timing imbalances for more accurate profiling measurements. - """ - dummy_tensor = torch.zeros(4, dtype=torch.float32, device='cuda') - dummy_tensor = self.all_reduce(dummy_tensor) - return dummy_tensor - def reducescatter_or_allreduce( self, inputs, diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index e6d7797b9b..ead2510b0d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -1,3 +1,4 @@ +import os import weakref from abc import abstractmethod from enum import Enum, IntEnum @@ -200,11 +201,19 @@ class MoE(nn.Module): self.intermediate_size_per_partition = intermediate_size // self.tp_size self.all_reduce = None + # Debug function for eliminating imbalance during performance analysis. 
+        self.enable_dummy_allreduce = os.environ.get(
+            "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
         if not self.use_dp and self.mapping.tp_size > 1:
             self.all_reduce = AllReduce(
                 mapping=self.mapping,
                 strategy=model_config.allreduce_strategy,
                 dtype=self.dtype)
+        elif self.enable_dummy_allreduce:
+            from tensorrt_llm.functional import AllReduceStrategy
+            self.all_reduce = AllReduce(mapping=self.mapping,
+                                        strategy=AllReduceStrategy.NCCL,
+                                        dtype=self.dtype)
 
         # Initialize load balancer related attributes
         if init_load_balancer:
@@ -748,3 +757,20 @@ class MoE(nn.Module):
             elif self.reduce_results:
                 outputs = self.all_reduce(inputs)
         return outputs
+
+    def dummy_allreduce(self):
+        """
+        Debug function for eliminating imbalance during performance analysis.
+        Creates a small dummy tensor and performs allreduce to synchronize processes
+        and eliminate timing imbalances for more accurate profiling measurements.
+
+        Returns:
+            The result of passing the 4-element float32 dummy tensor through self.all_reduce.
+
+        Raises:
+            AssertionError: If enable_dummy_allreduce is false or self.all_reduce is None.
+        """
+        assert self.enable_dummy_allreduce and self.all_reduce is not None, "Dummy allreduce is not enabled"
+        dummy_tensor = torch.zeros(4, dtype=torch.float32, device="cuda")
+        dummy_tensor = self.all_reduce(dummy_tensor)
+        return dummy_tensor