[None][perf] Improve the performance of online EPLB on Hopper by better overlapping (#6624)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Jinyang Yuan 2025-08-12 09:25:13 +08:00 committed by GitHub
parent be9dd4713c
commit ead89a0e40
15 changed files with 302 additions and 229 deletions

View File

@ -453,7 +453,7 @@ class Deepseekv3MoE(nn.Module):
False,  # In both low-latency and attention-DP modes, FusedMoE skips the in-op allreduce.
model_config=model_config,
override_quant_config=override_quant_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
aux_stream_dict=aux_stream_dict,
layer_idx=layer_idx)
self.mapping = model_config.mapping
@ -1049,11 +1049,12 @@ class DeepseekV3Model(DecoderModel):
config = model_config.pretrained_config
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
AuxStreamType.MoeBalancer: aux_stream_list[2],
}
self.embed_tokens = Embedding(

View File

@ -15,6 +15,7 @@ from ..modules.embedding import Embedding
from ..modules.fused_moe import RenormalizeMoeRoutingMethod, create_moe
from ..modules.linear import Linear
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
@ -49,7 +50,7 @@ class MixtralMoE(nn.Module):
routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
hidden_size=self.hidden_dim,
intermediate_size=self.ffn_dim,
aux_stream=aux_stream,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
dtype=config.torch_dtype,
reduce_results=reduce_results,
model_config=model_config,

View File

@ -22,6 +22,7 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod,
from ..modules.linear import TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import AuxStreamType
from .modeling_qwen3 import Qwen3Attention
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@ -107,7 +108,7 @@ class Qwen3MoE(nn.Module):
routing_method=self.gate.routing_method,
hidden_size=self.hidden_dim,
intermediate_size=self.moe_intermediate_size,
aux_stream=aux_stream,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
dtype=config.torch_dtype,
reduce_results=False,
model_config=model_config,

View File

@ -17,6 +17,7 @@ from ..modules.fused_moe import DefaultMoeRoutingMethod, create_moe
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
@ -53,7 +54,7 @@ class QwenMoE(nn.Module):
routing_method=DefaultMoeRoutingMethod(top_k=self.top_k),
hidden_size=self.hidden_dim,
intermediate_size=self.moe_intermediate_size,
aux_stream=aux_stream,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
dtype=config.torch_dtype,
reduce_results=reduce_results,
model_config=model_config,

View File

@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import Dict, Optional, Type
import torch
@ -6,6 +6,7 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
from ...model_config import ModelConfig
from ...utils import AuxStreamType
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_deepgemm import DeepGemmFusedMoE
@ -66,7 +67,7 @@ def create_moe(
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
override_quant_config: Optional[QuantConfig] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
bias: bool = False,
apply_router_weight_on_input: bool = False,
@ -123,7 +124,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
bias=bias,
apply_router_weight_on_input=apply_router_weight_on_input,
@ -141,7 +142,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
@ -169,7 +170,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
@ -183,7 +184,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
@ -199,7 +200,6 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
bias=bias,
layer_idx=layer_idx,
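Callers that overlap only MoE chunking now wrap their single stream in a one-entry dict, as the Mixtral and Qwen hunks above show. A minimal sketch of the new call shape (the argument values are placeholders and the absolute import path for AuxStreamType is an assumption):

```python
import torch

from tensorrt_llm._torch.utils import AuxStreamType  # import path assumed

aux_stream = torch.cuda.Stream()

# Before this commit: create_moe(..., aux_stream=aux_stream)
moe = create_moe(
    routing_method=routing_method,  # placeholder, as in the callers above
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
    dtype=dtype,
    reduce_results=False,
    model_config=model_config,
)
```

Backends that never use an auxiliary stream (Triton, TRTLLM-Gen) simply drop the parameter, which is why the last dispatch branch above no longer forwards it.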

View File

@ -1,5 +1,5 @@
import math
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@ -7,7 +7,7 @@ import torch.nn.functional as F
from tensorrt_llm._utils import get_sm_version
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from ...utils import AuxStreamType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod
@ -97,7 +97,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
@ -118,7 +118,8 @@ class CuteDslFusedMoE(CutlassFusedMoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType,
torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
apply_router_weight_on_input: bool = False,
@ -133,7 +134,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,

View File

@ -9,7 +9,8 @@ from tensorrt_llm.math_utils import pad_up
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import EventType, Fp4QuantizedTensor, ceil_div, swizzle_sf
from ...utils import (AuxStreamType, EventType, Fp4QuantizedTensor, ceil_div,
swizzle_sf)
from .interface import MoE
# isort: off
@ -31,7 +32,7 @@ class CutlassFusedMoE(MoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
@ -60,7 +61,8 @@ class CutlassFusedMoE(MoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType,
torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
bias: bool = False,
@ -115,8 +117,10 @@ class CutlassFusedMoE(MoE):
self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream(
)
self.aux_stream = aux_stream_dict[
AuxStreamType.
MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
)
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeChunkingOverlap]
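The yapf-wrapped conditional above is hard to scan; behaviorally it is a dict lookup with a fallback. A sketch of the equivalent logic (not a drop-in replacement):

```python
# Equivalent, restated: the auxiliary stream and events exist only when
# MoE chunking is active (moe_max_num_tokens < max_num_tokens).
if self.moe_max_num_tokens < max_num_tokens:
    if aux_stream_dict is not None:
        self.aux_stream = aux_stream_dict[AuxStreamType.MoeChunkingOverlap]
    else:
        self.aux_stream = torch.cuda.Stream()
    self.event_dict = {
        key: torch.cuda.Event()
        for key in [EventType.Main, EventType.MoeChunkingOverlap]
    }
```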

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@ -11,7 +11,7 @@ from tensorrt_llm._utils import nvtx_range
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from ...utils import AuxStreamType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@ -299,7 +299,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
@ -320,7 +320,8 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType,
torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
apply_router_weight_on_input: bool = False,
@ -335,7 +336,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,

View File

@ -1207,7 +1207,6 @@ class TritonFusedMoE(MoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
bias: bool = False,

View File

@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.

View File

@ -6,12 +6,13 @@ import torch
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
from ...distributed import allgather, reducescatter
from ...distributed import AllReduce, allgather, reducescatter
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig
from ...utils import EventType, Fp4QuantizedTensor
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_load_balancer import get_moe_load_balancer
@ -43,7 +44,7 @@ class WideEPMoE(MoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
@ -64,7 +65,8 @@ class WideEPMoE(MoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType,
torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
apply_router_weight_on_input: bool = False,
@ -106,7 +108,11 @@ class WideEPMoE(MoE):
top_k = self.routing_method.experts_per_token
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.layer_load_balancer = moe_load_balancer.add_layer(
self.num_experts, top_k, self.expert_size_per_partition)
self.num_experts,
top_k,
self.expert_size_per_partition,
aux_stream=None if aux_stream_dict is None else
aux_stream_dict[AuxStreamType.MoeBalancer])
self.repeat_count = self.layer_load_balancer.get_repeat_count()
loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
self.layer_idx)
@ -130,6 +136,12 @@ class WideEPMoE(MoE):
assert num_experts % self.ep_size == 0
self.expert_size_per_partition = num_experts // self.ep_size
self.num_slots = num_experts
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=AllReduceStrategy.NCCL)
else:
self.allreduce = None
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
@ -144,8 +156,10 @@ class WideEPMoE(MoE):
self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream(
)
self.aux_stream = aux_stream_dict[
AuxStreamType.
MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
)
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeChunkingOverlap]
@ -371,9 +385,8 @@ class WideEPMoE(MoE):
is_first_call, is_last_call = repeating_info
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_first_call:
self.layer_load_balancer.wait_for_gpu_stage()
if self.layer_load_balancer and is_first_call:
self.layer_load_balancer.start_wait_gpu_stage()
use_deepseek_fp8_block_scale = False
use_w4_group_scaling = False
@ -401,40 +414,32 @@ class WideEPMoE(MoE):
else:
token_final_scales = None
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_first_call:
self.layer_load_balancer.maybe_cudagraph_done_wait()
use_allgather = not use_all_to_all
loadbalancer_local_statistic_info = None
gathered_loadbalancer_local_statistic_info = None
token_selected_experts_for_statistic = None
if self.layer_load_balancer is None:
token_selected_slots = token_selected_experts
else:
if not self.layer_load_balancer.is_static_routing(
) and use_all_to_all:
self.layer_load_balancer.local_statistic(
if self.layer_load_balancer:
if is_first_call:
self.layer_load_balancer.done_wait_gpu_stage()
if use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL:
self.layer_load_balancer.update_local_statistic(
token_selected_experts,
is_first_stage=is_first_call,
is_last_stage=is_last_call)
else:
self.layer_load_balancer.update_statistic_with_local_ids(
token_selected_experts,
is_first_stage=is_first_call,
is_last_stage=is_last_call,
allreduce=self.allreduce)
token_selected_slots = self.layer_load_balancer.route(
token_selected_experts, self.use_dp)
if not self.layer_load_balancer.is_static_routing():
# split into two parts to allow possible overlap with load balancer routing
if use_all_to_all:
if is_last_call:
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
)
else:
token_selected_experts_for_statistic = token_selected_experts
else:
token_selected_slots = token_selected_experts
# If the load balancer is disabled, the statistics are collected from expert IDs.
# If the load balancer is enabled, the statistics are collected from expert slot IDs.
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
use_allgather = not use_all_to_all
# If alltoall is disabled, we also need to disable use_postquant_alltoall
use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
@ -456,6 +461,11 @@ class WideEPMoE(MoE):
self.dummy_allreduce()
token_count = x.shape[0]
alltoall_info = None
if is_last_call:
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
)
else:
loadbalancer_local_statistic_info = None
x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens,
x,
@ -463,6 +473,11 @@ class WideEPMoE(MoE):
token_final_scales,
use_postquant_alltoall,
loadbalancer_local_statistic_info)
if gathered_loadbalancer_local_statistic_info is not None:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
self.layer_load_balancer.update_statistic_with_gathered_statistic(
gathered_loadbalancer_local_statistic_info)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if not use_postquant_alltoall:
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
@ -470,10 +485,6 @@ class WideEPMoE(MoE):
self.expert_size_per_partition * self.mapping.moe_ep_rank)
padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
x, None, recv_topk_idx, token_final_scales)
if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
gathered_loadbalancer_local_statistic_info = allgather(
loadbalancer_local_statistic_info, self.mapping, dim=0)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
if not use_postquant_alltoall:
deep_ep_topk_idx = token_selected_slots
@ -503,10 +514,6 @@ class WideEPMoE(MoE):
x.shape[0], 1)
token_final_scales = torch.ones_like(
token_selected_slots, dtype=token_final_scales.dtype)
if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
gathered_loadbalancer_local_statistic_info = allgather(
loadbalancer_local_statistic_info, self.mapping, dim=0)
x_sf = None
x_row = x.shape[0]
@ -550,33 +557,18 @@ class WideEPMoE(MoE):
# using allgather case.
if self.enable_dummy_allreduce:
self.dummy_allreduce()
x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
x, x_sf, token_selected_slots, token_final_scales = allgather(
[
x,
x_sf,
token_selected_slots,
token_final_scales,
token_selected_experts_for_statistic,
],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
x_row = x.shape[0]
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
if use_all_to_all:
if is_last_call:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
self.layer_load_balancer.update_statistic(
gathered_loadbalancer_local_statistic_info)
else:
self.layer_load_balancer.statistic(
gathered_token_selected_experts_for_statistic,
is_first_stage=is_first_call,
is_last_stage=is_last_call)
ep_size = self.ep_size
ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight
@ -670,9 +662,8 @@ class WideEPMoE(MoE):
tuner_top_k=tuner_top_k,
)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_last_call:
self.layer_load_balancer.set_cpu_stage()
if self.layer_load_balancer and is_last_call:
self.layer_load_balancer.start_set_cpu_stage()
# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.
@ -702,9 +693,8 @@ class WideEPMoE(MoE):
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_last_call:
self.layer_load_balancer.maybe_cudagraph_done_set_cpu_stage()
if self.layer_load_balancer and is_last_call:
self.layer_load_balancer.done_set_cpu_stage()
return final_hidden_states
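Stripped of dispatch details, the load-balancer choreography these forward() hunks implement is roughly the following sketch (`lb` abbreviates `self.layer_load_balancer`; the update_* methods no-op internally when updates are disabled, i.e. for static routing):

```python
if lb and is_first_call:
    lb.start_wait_gpu_stage()          # issued on the MoeBalancer aux stream

# ... the routing gate computes token_selected_experts, overlapping the wait ...

if lb:
    if is_first_call:
        lb.done_wait_gpu_stage()       # main stream joins only when slots are needed
    if use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL:
        lb.update_local_statistic(token_selected_experts,
                                  is_first_stage=is_first_call,
                                  is_last_stage=is_last_call)
    else:
        lb.update_statistic_with_local_ids(token_selected_experts,
                                           is_first_stage=is_first_call,
                                           is_last_stage=is_last_call,
                                           allreduce=self.allreduce)
    token_selected_slots = lb.route(token_selected_experts, self.use_dp)
else:
    token_selected_slots = token_selected_experts

# ... dispatch, expert computation, combine ...

if lb and is_last_call:
    lb.start_set_cpu_stage()           # statistics handed to the CPU balancer, async
# ... unpack final_hidden_states ...
if lb and is_last_call:
    lb.done_set_cpu_stage()
```

Compared with the old code, none of these brackets is gated on is_static_routing() any more; that state tracking moved into SingleLayerMoeLoadBalancer itself.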

View File

@ -24,7 +24,6 @@ class MoE(nn.Module):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.

View File

@ -13,6 +13,9 @@ from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ...distributed import AllReduce
from ...utils import EventType
def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
"""
@ -271,7 +274,8 @@ class SingleLayerMoeLoadBalancer:
shared_mpi_comm: MPI.Comm,
expert_count: int,
updates_enabled: bool = True,
repeated_count=1):
repeated_count=1,
aux_stream: Optional[torch.cuda.Stream] = None):
"""
Initialize a SingleLayerMoeLoadBalancer instance.
@ -287,6 +291,7 @@ class SingleLayerMoeLoadBalancer:
)
self.expert_count = expert_count
self.updates_enabled = updates_enabled
self.repeated_count = repeated_count
layer_id = self.single_layer_load_balancer_impl.get_layer_id()
self.host_tensor_sharer = HostMoeTensorSharer(
layer_id, expert_count,
@ -303,15 +308,34 @@ class SingleLayerMoeLoadBalancer:
self.expert_count)
self.load_expert_ids = list(range(load_expert_start, load_expert_end))
if self.updates_enabled:
self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream(
)
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeBalancer]
}
else:
self.aux_stream = None
self.event_dict = None
self.statistic_flag_tensor = None
self.local_statistic_tensor = None
self.cudagraph_stream = None
self.cudagraph_event = None
self.repeated_count = repeated_count
self.statistic_stream = None
self.statistic_event = None
self.func_called_count = {
name: 0
for name in [
"start_wait_gpu_stage",
"done_wait_gpu_stage",
"start_set_cpu_stage",
"done_set_cpu_stage",
"update_local_statistic",
"get_local_statistic_tensor",
"update_statistic_with_gathered_statistic",
"update_statistic_with_local_ids",
"update_statistic_with_global_ids",
"route",
]
}
def get_layer_idx(self):
return self.single_layer_load_balancer_impl.get_layer_id()
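func_called_count is less a set of counters than a per-iteration state machine: the assertions in the methods below enforce the call order sketched here, and done_set_cpu_stage resets every counter for the next iteration. Exactly one of the three statistics paths may be taken:

```python
# Sketch of the order the assertions enforce (lb is a SingleLayerMoeLoadBalancer)
lb.start_wait_gpu_stage()                 # must run first, exactly once
lb.done_wait_gpu_stage()                  # requires start_wait_gpu_stage
# then exactly one statistics path:
#   a) lb.update_statistic_with_global_ids(gathered_ids, first, last)
#   b) lb.update_statistic_with_local_ids(local_ids, first, last, allreduce)
#   c) lb.update_local_statistic(local_ids, first, last)
#      stats = lb.get_local_statistic_tensor()
#      lb.update_statistic_with_gathered_statistic(gathered_stats)
slots = lb.route(token_selected_experts)  # requires done_wait_gpu_stage
lb.start_set_cpu_stage()                  # requires done_wait_gpu_stage
lb.done_set_cpu_stage()                   # resets func_called_count
```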
@ -441,139 +465,94 @@ class SingleLayerMoeLoadBalancer:
self.host_tensor_sharer.finalize_host_tensor_sharing(
self._add_host_weight_from_tensor)
def wait_for_gpu_stage(self) -> Optional[torch.Tensor]:
def start_wait_gpu_stage(self):
"""
Wait for the GPU stage to complete.
Returns:
A tensor indicating whether the stage is enabled
Start waiting for the GPU stage to complete.
"""
assert self.func_called_count["start_wait_gpu_stage"] == 0
self.func_called_count["start_wait_gpu_stage"] += 1
if self.updates_enabled:
assert self.statistic_flag_tensor is None, \
"Already has statistic_flag_tensor, should not wait."
if is_graph_capturing():
self.cudagraph_event = torch.cuda.Event()
self.cudagraph_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.cudagraph_stream):
current_stream_event.wait()
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
self.single_layer_load_balancer_ptr)
self.cudagraph_event.record(self.cudagraph_stream)
self.event_dict[EventType.MoeBalancer].record()
else:
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
self.single_layer_load_balancer_ptr)
return self.statistic_flag_tensor
else:
return
def maybe_cudagraph_done_wait(self):
def done_wait_gpu_stage(self):
"""
Done waiting for the GPU stage to complete.
"""
assert self.func_called_count["start_wait_gpu_stage"] == 1
assert self.func_called_count["done_wait_gpu_stage"] == 0
self.func_called_count["done_wait_gpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
self.cudagraph_event.wait()
self.event_dict[EventType.MoeBalancer].wait()
def set_cpu_stage(self):
def start_set_cpu_stage(self):
"""
Set the CPU stage.
Start setting the CPU stage.
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count["start_set_cpu_stage"] == 0
self.func_called_count["start_set_cpu_stage"] += 1
if self.updates_enabled:
assert self.statistic_flag_tensor is not None, \
"Doesn't have statistic_flag_tensor, should not set_cpu_stage."
self.statistic_flag_tensor = None
if is_graph_capturing():
assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage."
assert self.statistic_event is not None
assert self.statistic_stream is not None
# wait statistic update done
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.cudagraph_stream):
self.statistic_event.wait()
current_stream_event.wait()
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)
self.cudagraph_event.record(self.cudagraph_stream)
self.statistic_event = None
self.statistic_stream = None
self.event_dict[EventType.MoeBalancer].record()
else:
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)
def maybe_cudagraph_done_set_cpu_stage(self):
def done_set_cpu_stage(self):
"""
Done setting the CPU stage.
"""
assert self.func_called_count["start_set_cpu_stage"] == 1
for name in self.func_called_count:
self.func_called_count[name] = 0
self.statistic_flag_tensor = None
if self.updates_enabled:
if is_graph_capturing():
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
self.cudagraph_event.wait()
self.cudagraph_stream = None
self.cudagraph_event = None
self.event_dict[EventType.MoeBalancer].wait()
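This start_*/done_* split is what buys the overlap: each start_* records where the main stream is, runs the balancer op on the dedicated MoeBalancer stream, and records a completion event, while the matching done_* merely makes the main stream wait on that event. Everything issued between the two calls overlaps with the balancer work. A self-contained sketch of the pattern (balancer_work and independent_work are hypothetical placeholders for the trtllm ops and the intervening model code):

```python
import torch

main = torch.cuda.current_stream()
aux = torch.cuda.Stream()            # plays the role of the MoeBalancer stream
fork_event = torch.cuda.Event()      # EventType.Main above
join_event = torch.cuda.Event()      # EventType.MoeBalancer above

def balancer_work():                 # stand-in for e.g. moe_load_balance_wait_gpu_stage
    pass

def independent_work():              # stand-in for routing / gate computation
    pass

# start_*: fork onto the aux stream without blocking the main stream
fork_event.record(main)
with torch.cuda.stream(aux):
    fork_event.wait()                # aux begins after main's already-issued work
    balancer_work()
    join_event.record()

independent_work()                   # issued on main; overlaps balancer_work()

# done_*: join; main stream proceeds only once the balancer op has finished
join_event.wait(main)
```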
def statistic(self, gathered_raw_expert_ids: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
"""
Perform statistics on the expert IDs.
Update local statistics of the expert IDs.
Args:
gathered_raw_expert_ids: The gathered raw expert IDs from all ranks
local_raw_expert_ids: The local raw expert IDs
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count["update_statistic_with_global_ids"] == 0
self.func_called_count["update_local_statistic"] += 1
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
if is_graph_capturing():
if is_first_stage:
self.statistic_event = torch.cuda.Event()
self.statistic_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.statistic_stream):
current_stream_event.wait()
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
self.statistic_event.record()
else:
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
def local_statistic(self, local_raw_expert_ids: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
"""
Perform local statistics on the expert IDs.
Args:
local_raw_expert_ids: The gathered raw expert IDs from all ranks
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
if is_first_stage:
assert self.local_statistic_tensor is None
if self.local_statistic_tensor is None:
self.local_statistic_tensor = torch.empty(
(self.expert_count, ),
dtype=torch.int32,
device=torch.device('cuda'))
if is_graph_capturing():
if is_first_stage:
self.statistic_event = torch.cuda.Event()
self.statistic_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.statistic_stream):
current_stream_event.wait()
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_hierarchical_statistic_local_device(
local_raw_expert_ids, self.local_statistic_tensor,
self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
self.statistic_event.record(self.statistic_stream)
else:
torch.ops.trtllm.moe_hierarchical_statistic_local_device(
local_raw_expert_ids, self.local_statistic_tensor,
@ -581,48 +560,119 @@ class SingleLayerMoeLoadBalancer:
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
def get_local_statistic_tensor(self):
def get_local_statistic_tensor(self) -> Optional[torch.Tensor]:
"""
Get the local statistic tensor. Should perform allreduce on it and then call update_statistic
Get the local statistic tensor.
Returns:
The local statistic tensor if statistics are enabled, otherwise None
"""
assert self.func_called_count["update_local_statistic"] > 0
self.func_called_count["get_local_statistic_tensor"] += 1
if self.updates_enabled:
assert self.local_statistic_tensor is not None
if is_graph_capturing():
assert self.statistic_event is not None
assert self.statistic_stream is not None
self.statistic_event.wait()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeBalancer].record()
self.event_dict[EventType.MoeBalancer].wait()
return self.local_statistic_tensor
return None
def update_statistic(self, gathered_local_statistic_tensor: torch.Tensor):
def update_statistic_with_gathered_statistic(
self, gathered_local_statistic_tensor: torch.Tensor):
"""
Perform update with global statistics.
Update statistics of the expert IDs, using gathered local statistic tensors.
Args:
gathered_local_statistic_tensor: gathered local statistics info, should have shape (world_size, self.expert_count)
"""
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
assert self.func_called_count["get_local_statistic_tensor"] > 0
assert self.func_called_count["update_statistic_with_local_ids"] == 0
assert self.func_called_count["update_statistic_with_global_ids"] == 0
self.func_called_count["update_statistic_with_gathered_statistic"] += 1
def _update_statistic():
global_statistic_info = torch.sum(
gathered_local_statistic_tensor, dim=0, dtype=torch.int32)
def _update_statistic():
global_statistic_info = torch.sum(gathered_local_statistic_tensor,
dim=0,
dtype=torch.int32)
torch.ops.trtllm.moe_hierarchical_statistic_update(
global_statistic_info, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr)
if self.updates_enabled:
if is_graph_capturing():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
_update_statistic()
else:
_update_statistic()
def update_statistic_with_local_ids(self,
local_raw_expert_ids: torch.Tensor,
is_first_stage: bool,
is_last_stage: bool,
allreduce: Optional[AllReduce] = None):
"""
Update statistics of the expert IDs, using local raw expert IDs.
Args:
local_raw_expert_ids: The local raw expert IDs
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
allreduce: The AllReduce op used to reduce the local statistic tensor across ranks
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count[
"update_statistic_with_gathered_statistic"] == 0
assert self.func_called_count["update_statistic_with_global_ids"] == 0
self.func_called_count["update_statistic_with_local_ids"] += 1
def _update_statistic():
if is_last_stage:
global_statistic_info = allreduce(self.local_statistic_tensor)
torch.ops.trtllm.moe_hierarchical_statistic_update(
global_statistic_info, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr)
if self.updates_enabled:
self.update_local_statistic(local_raw_expert_ids, is_first_stage,
is_last_stage)
if is_graph_capturing():
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.statistic_stream):
current_stream_event.wait()
with torch.cuda.stream(self.aux_stream):
_update_statistic()
self.statistic_event.record(self.statistic_stream)
else:
_update_statistic()
self.local_statistic_tensor = None
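On the MNNVL alltoall path, by contrast, the local counts are read out with get_local_statistic_tensor on the last call, gathered as a side effect of alltoall_prepare_maybe_dispatch, and applied afterwards. A sketch of that leg, with names following the forward() hunks earlier in this commit:

```python
# MNNVL path (sketch): statistics ride along with the alltoall metadata
if is_last_call:
    local_stats = lb.get_local_statistic_tensor()   # syncs main with the aux stream
else:
    local_stats = None

x, token_selected_slots, token_final_scales, gathered_stats, alltoall_info = \
    self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens, x,
                                         token_selected_slots, token_final_scales,
                                         use_postquant_alltoall, local_stats)

if gathered_stats is not None:                      # only on the last call
    gathered_stats = gathered_stats.view(self.mapping.moe_ep_size, self.num_experts)
    lb.update_statistic_with_gathered_statistic(gathered_stats)
```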
def update_statistic_with_global_ids(self,
gathered_raw_expert_ids: torch.Tensor,
is_first_stage: bool,
is_last_stage: bool):
"""
Update statistics of the expert IDs, using gathered raw expert IDs from all ranks.
Args:
gathered_raw_expert_ids: The gathered raw expert IDs from all ranks
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count[
"update_statistic_with_gathered_statistic"] == 0
assert self.func_called_count["update_statistic_with_local_ids"] == 0
self.func_called_count["update_statistic_with_global_ids"] += 1
if self.updates_enabled:
if is_graph_capturing():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
else:
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
def route(self,
token_selected_experts: torch.Tensor,
@ -637,6 +687,8 @@ class SingleLayerMoeLoadBalancer:
Returns:
A tensor of routed slot IDs
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
self.func_called_count["route"] += 1
return torch.ops.trtllm.moe_load_balance_routing(
token_selected_experts, offset_by_ep_rank,
self.single_layer_load_balancer_ptr)
@ -731,8 +783,13 @@ class MoeLoadBalancer:
assert repeated_count > 0, "repeat count must be greater than 0"
self.next_layer_repeated_count = repeated_count
def add_layer(self, expert_count: int, top_k: int,
slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
def add_layer(
self,
expert_count: int,
top_k: int,
slot_count_per_rank: int,
aux_stream: Optional[torch.cuda.Stream] = None
) -> SingleLayerMoeLoadBalancer:
"""
Add a new layer to the load balancer.
@ -740,6 +797,7 @@ class MoeLoadBalancer:
expert_count: The number of experts in the layer
top_k: The number of experts each token selects
slot_count_per_rank: The number of slots per rank
aux_stream: The auxiliary stream for overlapping
Returns:
A SingleLayerMoeLoadBalancer instance for the new layer
@ -756,7 +814,8 @@ class MoeLoadBalancer:
self.shared_mpi_comm,
expert_count,
updates_enabled=updates_enabled,
repeated_count=repeat_count)
repeated_count=repeat_count,
aux_stream=aux_stream)
single_layer_load_balancer.set_shared_memory_base_name(
self.shared_memory_base_name)
self.single_layer_load_balancers.append(single_layer_load_balancer)
@ -988,8 +1047,11 @@ def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int):
def moe_load_balancer_add_single_layer(
expert_count: int, top_k: int,
slot_count_per_rank: int) -> Optional[SingleLayerMoeLoadBalancer]:
expert_count: int,
top_k: int,
slot_count_per_rank: int,
aux_stream: Optional[torch.cuda.Stream] = None
) -> Optional[SingleLayerMoeLoadBalancer]:
"""
Add a new layer to the current active MoeLoadBalancer.
@ -997,11 +1059,13 @@ def moe_load_balancer_add_single_layer(
expert_count: The number of experts in the layer
top_k: The number of experts each token selects
slot_count_per_rank: The number of slots per rank
aux_stream: The auxiliary stream for overlapping
Returns:
A SingleLayerMoeLoadBalancer instance for the new layer, or None if not in a MoeLoadBalancer context
"""
load_balancer = get_moe_load_balancer()
if load_balancer is not None:
return load_balancer.add_layer(expert_count, top_k, slot_count_per_rank)
return load_balancer.add_layer(expert_count, top_k, slot_count_per_rank,
aux_stream)
return None

View File

@ -12,7 +12,12 @@ from tensorrt_llm.quantization.utils import fp4_utils
is_torch_compiling_flag = False
aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
aux_stream_name_list = [
'Attention',
'MoeShared',
'MoeChunkingOverlap',
'MoeBalancer',
]
AuxStreamType = Enum(
'AuxStreamType',
aux_stream_name_list,

View File

@ -222,22 +222,18 @@ class TestMoeLoadBalancer(unittest.TestCase):
# wait_for_gpu_stage
mock_wait.return_value = torch.tensor([1])
layer.wait_for_gpu_stage()
layer.start_wait_gpu_stage()
layer.done_wait_gpu_stage()
result = layer.statistic_flag_tensor
mock_wait.assert_called_once_with(
mock_single_layer_impl.get_pointer())
self.assertEqual(result, mock_wait.return_value)
# set_cpu_stage
layer.set_cpu_stage()
mock_set_cpu.assert_called_once_with(
mock_single_layer_impl.get_pointer())
# statistic
mock_expert_ids = torch.tensor([[0, 1], [2, 3]])
mock_enabled = torch.tensor([1])
layer.statistic_flag_tensor = mock_enabled
layer.statistic(mock_expert_ids, True, False)
layer.update_statistic_with_global_ids(mock_expert_ids, True, False)
mock_statistic.assert_called_once_with(
mock_expert_ids, mock_enabled,
mock_single_layer_impl.get_pointer(), True, False)
@ -248,6 +244,12 @@ class TestMoeLoadBalancer(unittest.TestCase):
result = layer.route(mock_selected_experts)
assert torch.equal(result, mock_route.return_value)
# set_cpu_stage
layer.start_set_cpu_stage()
layer.done_set_cpu_stage()
mock_set_cpu.assert_called_once_with(
mock_single_layer_impl.get_pointer())
@patch('tensorrt_llm.bindings.internal.runtime.MoeLoadBalancer')
def test_moe_load_balancer_lifecycle_methods(self, mock_load_balancer_impl):
"""Test lifecycle methods of MoeLoadBalancer."""
@ -323,13 +325,16 @@ class TestMoeLoadBalancer(unittest.TestCase):
try:
with MoeLoadBalancerIterContext(balancer):
# Wait for GPU stage and get enabled flag
layer.wait_for_gpu_stage()
layer.start_wait_gpu_stage()
layer.done_wait_gpu_stage()
# Run the statistics update - just test it runs without error
layer.statistic(gathered_raw_expert_ids, True, True)
layer.update_statistic_with_global_ids(gathered_raw_expert_ids,
True, True)
# Set CPU stage to signal completion
layer.set_cpu_stage()
layer.start_set_cpu_stage()
layer.done_set_cpu_stage()
# Test passed if we got here without exceptions
self.assertTrue(True, "Statistic kernel ran successfully")
@ -384,13 +389,15 @@ class TestMoeLoadBalancer(unittest.TestCase):
try:
with MoeLoadBalancerIterContext(balancer):
# Wait for GPU stage
layer.wait_for_gpu_stage()
layer.start_wait_gpu_stage()
layer.done_wait_gpu_stage()
# Run routing
routed_slots = layer.route(token_selected_experts)
# Set CPU stage
layer.set_cpu_stage()
layer.start_set_cpu_stage()
layer.done_set_cpu_stage()
# Verify results - with our initial assignment, expert i should map to slot i
expected_slots = torch.tensor(