Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][feat] Add ExpertStatistic and DUMMY_ALLREDUCE for configurable_moe (#10401)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
This commit is contained in:
parent 5e0dbba0c9
commit 3a9a00b544
@@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 
+from tensorrt_llm._torch.expert_statistic import ExpertStatistic
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.fused_moe.interface import MoE
 from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod
@@ -619,6 +620,10 @@ class ConfigurableMoE(MoE):
         else:
             token_selected_slots = token_selected_experts
 
+        if token_selected_slots is not None:
+            ExpertStatistic.set_layer(self.layer_idx)
+            ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
+
         # ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ==========
         # NVLINK two-sided has a prepare phase to gather EPLB statistics
 
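The two ExpertStatistic calls above record, for the current layer, which expert slots each token was routed to. As a rough mental model only (this is not the actual tensorrt_llm._torch.expert_statistic implementation, which has its own enable logic and classmethod interface), a per-layer slot-count accumulator could look like the sketch below.

    # Illustrative sketch only -- not the real ExpertStatistic class.
    # Assumes token_selected_slots is an integer tensor of shape
    # [num_tokens, top_k] holding the slot chosen for each token/expert pair.
    import torch

    class SimpleExpertStatistic:
        """Accumulates how many tokens each expert slot receives, per layer."""

        def __init__(self):
            self.current_layer = None
            self.counts = {}  # layer_idx -> tensor of shape [num_slots]

        def set_layer(self, layer_idx):
            self.current_layer = layer_idx

        def maybe_add_info(self, num_slots, token_selected_slots):
            if self.current_layer is None:
                return
            hist = torch.bincount(token_selected_slots.flatten().to(torch.int64),
                                  minlength=num_slots)
            prev = self.counts.get(self.current_layer)
            self.counts[self.current_layer] = hist if prev is None else prev + hist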
@@ -647,6 +652,10 @@ class ConfigurableMoE(MoE):
         # supports_post_quant_dispatch checks strategy capability for the current quant mode
         supports_post_quant = self.comm.supports_post_quant_dispatch()
 
+        # Call dummy_allreduce before allgather for load balancing debug
+        if self.enable_dummy_allreduce:
+            self.dummy_allreduce()
+
         if supports_post_quant:
             # ===== Post-quant flow: Quantize → Dispatch =====
 
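The dummy allreduce added here is a debug-only synchronization point: every rank reduces a tiny tensor, so all ranks reach this line together before the (potentially imbalanced) dispatch begins, and per-phase profiling numbers stop absorbing each other's skew. A standalone sketch of the same idea using plain torch.distributed (TensorRT-LLM wraps this in its own AllReduce module instead) might look like this:

    # Standalone illustration of a "dummy allreduce as a barrier"; assumes a
    # process group has already been initialized, e.g. via torchrun + NCCL.
    import torch
    import torch.distributed as dist

    def dummy_allreduce_barrier(device="cuda"):
        # The payload is irrelevant; the collective only forces all ranks to
        # rendezvous here before continuing.
        dummy = torch.zeros(4, dtype=torch.float32, device=device)
        dist.all_reduce(dummy)
        return dummy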
@@ -710,6 +719,8 @@ class ConfigurableMoE(MoE):
 
         # ========== Step 9: Communication - Combine ==========
         if self.comm is not None:
+            if self.enable_dummy_allreduce:
+                self.dummy_allreduce()
             # Use unified combine interface (reads dispatch state from strategy)
             final_hidden_states = self.comm.combine(final_hidden_states)
         else:
@@ -159,10 +159,6 @@ class WideEPMoE(MoE):
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
 
-        # Debug function for eliminating imbalance during performance analysis.
-        self.enable_dummy_allreduce = os.environ.get(
-            "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
-
         # MoE op will be lazily initialized when first accessed (see moe_op_impl property)
         self._moe_op_impl = None
 
@@ -342,16 +338,6 @@ class WideEPMoE(MoE):
             self._moe_op_impl = MoEOpSelector.select_op(self)
         return self._moe_op_impl
 
-    def dummy_allreduce(self):
-        """
-        Debug function for eliminating imbalance during performance analysis.
-        Creates a small dummy tensor and performs allreduce to synchronize processes
-        and eliminate timing imbalances for more accurate profiling measurements.
-        """
-        dummy_tensor = torch.zeros(4, dtype=torch.float32, device='cuda')
-        dummy_tensor = self.all_reduce(dummy_tensor)
-        return dummy_tensor
-
     def reducescatter_or_allreduce(
         self,
         inputs,
@@ -1,3 +1,4 @@
+import os
 import weakref
 from abc import abstractmethod
 from enum import Enum, IntEnum
@@ -200,11 +201,19 @@ class MoE(nn.Module):
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
 
         self.all_reduce = None
+        # Debug function for eliminating imbalance during performance analysis.
+        self.enable_dummy_allreduce = os.environ.get(
+            "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
         if not self.use_dp and self.mapping.tp_size > 1:
             self.all_reduce = AllReduce(
                 mapping=self.mapping,
                 strategy=model_config.allreduce_strategy,
                 dtype=self.dtype)
+        elif self.enable_dummy_allreduce:
+            from tensorrt_llm.functional import AllReduceStrategy
+            self.all_reduce = AllReduce(mapping=self.mapping,
+                                        strategy=AllReduceStrategy.NCCL,
+                                        dtype=self.dtype)
 
         # Initialize load balancer related attributes
         if init_load_balancer:
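Two points about this hunk. First, the environment-variable gate moves from WideEPMoE into the MoE base class, so any MoE subclass can use it. Second, the elif branch appears to exist so that dummy_allreduce still has an AllReduce instance to call in configurations where the first branch does not create one; a plain NCCL strategy is sufficient for a four-element debug tensor. Enabling the path needs no code change, only the environment variable, e.g. export TRTLLM_ENABLE_DUMMY_ALLREDUCE=1 before launching, or from Python (the surrounding launch code below is a placeholder; only the variable name comes from this diff):

    # Hypothetical usage sketch: turn on the debug barrier for a profiling run.
    import os

    # Must be set before the MoE modules are constructed.
    os.environ["TRTLLM_ENABLE_DUMMY_ALLREDUCE"] = "1"

    # ... build and run the model as usual; ConfigurableMoE then issues a
    # small allreduce before dispatch and before combine, as in the hunks above.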
@@ -748,3 +757,14 @@ class MoE(nn.Module):
         elif self.reduce_results:
             outputs = self.all_reduce(inputs)
         return outputs
+
+    def dummy_allreduce(self):
+        assert self.enable_dummy_allreduce and self.all_reduce is not None, "Dummy allreduce is not enabled"
+        """
+        Debug function for eliminating imbalance during performance analysis.
+        Creates a small dummy tensor and performs allreduce to synchronize processes
+        and eliminate timing imbalances for more accurate profiling measurements.
+        """
+        dummy_tensor = torch.zeros(4, dtype=torch.float32, device="cuda")
+        dummy_tensor = self.all_reduce(dummy_tensor)
+        return dummy_tensor
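To see why this matters for performance analysis: without such a barrier, a fast rank starts timing the next phase while a slow rank is still finishing the previous one, so the measured times absorb the load imbalance rather than exposing it. A hypothetical sketch of how a timing harness might lean on such a barrier (none of these names are TensorRT-LLM APIs; dummy_allreduce_barrier is the standalone torch.distributed sketch shown earlier):

    # Hypothetical profiling sketch. run_phase is a placeholder callable for
    # the region being measured. Assumes an initialized process group and CUDA.
    import time
    import torch

    def timed_phase(run_phase):
        dummy_allreduce_barrier()   # align all ranks before starting the clock
        torch.cuda.synchronize()    # drain previously enqueued GPU work
        start = time.perf_counter()
        out = run_phase()
        torch.cuda.synchronize()
        return out, time.perf_counter() - start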