[None][feat] Add ExpertStatistic and DUMMY_ALLREDUCE for configurable_moe (#10401)

Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
This commit is contained in:
Xianjie Qiao 2026-01-12 14:10:31 +08:00 committed by GitHub
parent 5e0dbba0c9
commit 3a9a00b544
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 14 deletions

View File

@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
from tensorrt_llm._torch.expert_statistic import ExpertStatistic
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod
@ -619,6 +620,10 @@ class ConfigurableMoE(MoE):
else:
token_selected_slots = token_selected_experts
if token_selected_slots is not None:
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
# ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ==========
# NVLINK two-sided has a prepare phase to gather EPLB statistics
@ -647,6 +652,10 @@ class ConfigurableMoE(MoE):
# supports_post_quant_dispatch checks strategy capability for the current quant mode
supports_post_quant = self.comm.supports_post_quant_dispatch()
# Call dummy_allreduce before allgather for load balancing debug
if self.enable_dummy_allreduce:
self.dummy_allreduce()
if supports_post_quant:
# ===== Post-quant flow: Quantize → Dispatch =====
@ -710,6 +719,8 @@ class ConfigurableMoE(MoE):
# ========== Step 9: Communication - Combine ==========
if self.comm is not None:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
# Use unified combine interface (reads dispatch state from strategy)
final_hidden_states = self.comm.combine(final_hidden_states)
else:

View File

@ -159,10 +159,6 @@ class WideEPMoE(MoE):
if not model_config.skip_create_weights_in_init:
self.create_weights()
# Debug function for eliminating imbalance during performance analysis.
self.enable_dummy_allreduce = os.environ.get(
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
# MoE op will be lazily initialized when first accessed (see moe_op_impl property)
self._moe_op_impl = None
@ -342,16 +338,6 @@ class WideEPMoE(MoE):
self._moe_op_impl = MoEOpSelector.select_op(self)
return self._moe_op_impl
def dummy_allreduce(self):
"""
Debug function for eliminating imbalance during performance analysis.
Creates a small dummy tensor and performs allreduce to synchronize processes
and eliminate timing imbalances for more accurate profiling measurements.
"""
dummy_tensor = torch.zeros(4, dtype=torch.float32, device='cuda')
dummy_tensor = self.all_reduce(dummy_tensor)
return dummy_tensor
def reducescatter_or_allreduce(
self,
inputs,

View File

@ -1,3 +1,4 @@
import os
import weakref
from abc import abstractmethod
from enum import Enum, IntEnum
@ -200,11 +201,19 @@ class MoE(nn.Module):
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.all_reduce = None
# Debug function for eliminating imbalance during performance analysis.
self.enable_dummy_allreduce = os.environ.get(
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
if not self.use_dp and self.mapping.tp_size > 1:
self.all_reduce = AllReduce(
mapping=self.mapping,
strategy=model_config.allreduce_strategy,
dtype=self.dtype)
elif self.enable_dummy_allreduce:
from tensorrt_llm.functional import AllReduceStrategy
self.all_reduce = AllReduce(mapping=self.mapping,
strategy=AllReduceStrategy.NCCL,
dtype=self.dtype)
# Initialize load balancer related attributes
if init_load_balancer:
@ -748,3 +757,14 @@ class MoE(nn.Module):
elif self.reduce_results:
outputs = self.all_reduce(inputs)
return outputs
def dummy_allreduce(self):
assert self.enable_dummy_allreduce and self.all_reduce is not None, "Dummy allreduce is not enabled"
"""
Debug function for eliminating imbalance during performance analysis.
Creates a small dummy tensor and performs allreduce to synchronize processes
and eliminate timing imbalances for more accurate profiling measurements.
"""
dummy_tensor = torch.zeros(4, dtype=torch.float32, device="cuda")
dummy_tensor = self.all_reduce(dummy_tensor)
return dummy_tensor