mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Add ExpertStatistic and DUMMY_ALLREDUCE for configurable_moe (#10401)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
This commit is contained in:
parent
5e0dbba0c9
commit
3a9a00b544
@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.expert_statistic import ExpertStatistic
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
|
||||
from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod
|
||||
@ -619,6 +620,10 @@ class ConfigurableMoE(MoE):
|
||||
else:
|
||||
token_selected_slots = token_selected_experts
|
||||
|
||||
if token_selected_slots is not None:
|
||||
ExpertStatistic.set_layer(self.layer_idx)
|
||||
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
|
||||
|
||||
# ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ==========
|
||||
# NVLINK two-sided has a prepare phase to gather EPLB statistics
|
||||
|
||||
@ -647,6 +652,10 @@ class ConfigurableMoE(MoE):
|
||||
# supports_post_quant_dispatch checks strategy capability for the current quant mode
|
||||
supports_post_quant = self.comm.supports_post_quant_dispatch()
|
||||
|
||||
# Call dummy_allreduce before allgather for load balancing debug
|
||||
if self.enable_dummy_allreduce:
|
||||
self.dummy_allreduce()
|
||||
|
||||
if supports_post_quant:
|
||||
# ===== Post-quant flow: Quantize → Dispatch =====
|
||||
|
||||
@ -710,6 +719,8 @@ class ConfigurableMoE(MoE):
|
||||
|
||||
# ========== Step 9: Communication - Combine ==========
|
||||
if self.comm is not None:
|
||||
if self.enable_dummy_allreduce:
|
||||
self.dummy_allreduce()
|
||||
# Use unified combine interface (reads dispatch state from strategy)
|
||||
final_hidden_states = self.comm.combine(final_hidden_states)
|
||||
else:
|
||||
|
||||
@ -159,10 +159,6 @@ class WideEPMoE(MoE):
|
||||
if not model_config.skip_create_weights_in_init:
|
||||
self.create_weights()
|
||||
|
||||
# Debug function for eliminating imbalance during performance analysis.
|
||||
self.enable_dummy_allreduce = os.environ.get(
|
||||
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
|
||||
|
||||
# MoE op will be lazily initialized when first accessed (see moe_op_impl property)
|
||||
self._moe_op_impl = None
|
||||
|
||||
@ -342,16 +338,6 @@ class WideEPMoE(MoE):
|
||||
self._moe_op_impl = MoEOpSelector.select_op(self)
|
||||
return self._moe_op_impl
|
||||
|
||||
def dummy_allreduce(self):
|
||||
"""
|
||||
Debug function for eliminating imbalance during performance analysis.
|
||||
Creates a small dummy tensor and performs allreduce to synchronize processes
|
||||
and eliminate timing imbalances for more accurate profiling measurements.
|
||||
"""
|
||||
dummy_tensor = torch.zeros(4, dtype=torch.float32, device='cuda')
|
||||
dummy_tensor = self.all_reduce(dummy_tensor)
|
||||
return dummy_tensor
|
||||
|
||||
def reducescatter_or_allreduce(
|
||||
self,
|
||||
inputs,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import weakref
|
||||
from abc import abstractmethod
|
||||
from enum import Enum, IntEnum
|
||||
@ -200,11 +201,19 @@ class MoE(nn.Module):
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
|
||||
self.all_reduce = None
|
||||
# Debug function for eliminating imbalance during performance analysis.
|
||||
self.enable_dummy_allreduce = os.environ.get(
|
||||
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
|
||||
if not self.use_dp and self.mapping.tp_size > 1:
|
||||
self.all_reduce = AllReduce(
|
||||
mapping=self.mapping,
|
||||
strategy=model_config.allreduce_strategy,
|
||||
dtype=self.dtype)
|
||||
elif self.enable_dummy_allreduce:
|
||||
from tensorrt_llm.functional import AllReduceStrategy
|
||||
self.all_reduce = AllReduce(mapping=self.mapping,
|
||||
strategy=AllReduceStrategy.NCCL,
|
||||
dtype=self.dtype)
|
||||
|
||||
# Initialize load balancer related attributes
|
||||
if init_load_balancer:
|
||||
@ -748,3 +757,14 @@ class MoE(nn.Module):
|
||||
elif self.reduce_results:
|
||||
outputs = self.all_reduce(inputs)
|
||||
return outputs
|
||||
|
||||
def dummy_allreduce(self):
|
||||
assert self.enable_dummy_allreduce and self.all_reduce is not None, "Dummy allreduce is not enabled"
|
||||
"""
|
||||
Debug function for eliminating imbalance during performance analysis.
|
||||
Creates a small dummy tensor and performs allreduce to synchronize processes
|
||||
and eliminate timing imbalances for more accurate profiling measurements.
|
||||
"""
|
||||
dummy_tensor = torch.zeros(4, dtype=torch.float32, device="cuda")
|
||||
dummy_tensor = self.all_reduce(dummy_tensor)
|
||||
return dummy_tensor
|
||||
|
||||
Loading…
Reference in New Issue
Block a user