Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[None][perf] Improve the performance of online EPLB on Hopper by better overlapping (#6624)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
This commit is contained in: parent be9dd4713c, commit ead89a0e40
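At its core, the change moves MoE load-balancer work onto a dedicated auxiliary CUDA stream and splits each blocking call into a start/done pair so other kernels can run in between. Below is a minimal sketch of that pattern in plain PyTorch; the names and the `balancer_op` placeholder are illustrative, not TensorRT-LLM API:

import torch

# Assumes a CUDA device. `balancer_op` stands in for the real
# torch.ops.trtllm.* load-balancer kernel launched by a start_*() call.
main_ready = torch.cuda.Event()
balancer_done = torch.cuda.Event()
aux_stream = torch.cuda.Stream()

def start_balancer_work(balancer_op):
    # Record where the main stream is, then queue the balancer kernel
    # on the auxiliary stream behind that point.
    main_ready.record(torch.cuda.current_stream())
    with torch.cuda.stream(aux_stream):
        main_ready.wait()
        balancer_op()
        balancer_done.record()

def done_balancer_work():
    # The main stream only synchronizes here, right before the result
    # is consumed, so everything issued in between overlaps.
    balancer_done.wait()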
@@ -453,7 +453,7 @@ class Deepseekv3MoE(nn.Module):
False, # In both low-latency and attention-DP modes, FusedMoE skips the in-op all-reduce.
model_config=model_config,
override_quant_config=override_quant_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
aux_stream_dict=aux_stream_dict,
layer_idx=layer_idx)

self.mapping = model_config.mapping

@@ -1049,11 +1049,12 @@ class DeepseekV3Model(DecoderModel):
config = model_config.pretrained_config
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
AuxStreamType.MoeBalancer: aux_stream_list[2],
}

self.embed_tokens = Embedding(
@@ -15,6 +15,7 @@ from ..modules.embedding import Embedding
from ..modules.fused_moe import RenormalizeMoeRoutingMethod, create_moe
from ..modules.linear import Linear
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)

@@ -49,7 +50,7 @@ class MixtralMoE(nn.Module):
routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
hidden_size=self.hidden_dim,
intermediate_size=self.ffn_dim,
aux_stream=aux_stream,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
dtype=config.torch_dtype,
reduce_results=reduce_results,
model_config=model_config,
@@ -22,6 +22,7 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod,
from ..modules.linear import TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import AuxStreamType
from .modeling_qwen3 import Qwen3Attention
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model

@@ -107,7 +108,7 @@ class Qwen3MoE(nn.Module):
routing_method=self.gate.routing_method,
hidden_size=self.hidden_dim,
intermediate_size=self.moe_intermediate_size,
aux_stream=aux_stream,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
dtype=config.torch_dtype,
reduce_results=False,
model_config=model_config,
@@ -17,6 +17,7 @@ from ..modules.fused_moe import DefaultMoeRoutingMethod, create_moe
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)

@@ -53,7 +54,7 @@ class QwenMoE(nn.Module):
routing_method=DefaultMoeRoutingMethod(top_k=self.top_k),
hidden_size=self.hidden_dim,
intermediate_size=self.moe_intermediate_size,
aux_stream=aux_stream,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
dtype=config.torch_dtype,
reduce_results=reduce_results,
model_config=model_config,
@@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import Dict, Optional, Type

import torch

@@ -6,6 +6,7 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig

from ...model_config import ModelConfig
from ...utils import AuxStreamType
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_deepgemm import DeepGemmFusedMoE

@@ -66,7 +67,7 @@ def create_moe(
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
override_quant_config: Optional[QuantConfig] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
bias: bool = False,
apply_router_weight_on_input: bool = False,

@@ -123,7 +124,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
bias=bias,
apply_router_weight_on_input=apply_router_weight_on_input,

@@ -141,7 +142,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,

@@ -169,7 +170,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,

@@ -183,7 +184,7 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,

@@ -199,7 +200,6 @@ def create_moe(
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
bias=bias,
layer_idx=layer_idx,
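The factory now threads a dictionary of named streams through to the backends instead of a single `aux_stream`. A self-contained sketch of the lookup pattern this enables (the enum is redefined locally here purely for illustration):

from enum import Enum
from typing import Dict, Optional

import torch

AuxStreamType = Enum('AuxStreamType',
                     ['Attention', 'MoeShared', 'MoeChunkingOverlap', 'MoeBalancer'])

def pick_streams(aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None):
    # Each backend picks only the streams for its own overlap domains;
    # missing entries simply fall back to None.
    streams = aux_stream_dict or {}
    return (streams.get(AuxStreamType.MoeChunkingOverlap),
            streams.get(AuxStreamType.MoeBalancer))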
@@ -1,5 +1,5 @@
import math
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F

@@ -7,7 +7,7 @@ import torch.nn.functional as F
from tensorrt_llm._utils import get_sm_version

from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from ...utils import AuxStreamType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod

@@ -97,7 +97,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.

@@ -118,7 +118,8 @@ class CuteDslFusedMoE(CutlassFusedMoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
apply_router_weight_on_input: bool = False,

@@ -133,7 +134,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
@@ -9,7 +9,8 @@ from tensorrt_llm.math_utils import pad_up

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import EventType, Fp4QuantizedTensor, ceil_div, swizzle_sf
from ...utils import (AuxStreamType, EventType, Fp4QuantizedTensor, ceil_div, swizzle_sf)
from .interface import MoE

# isort: off

@@ -31,7 +32,7 @@ class CutlassFusedMoE(MoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.

@@ -60,7 +61,8 @@ class CutlassFusedMoE(MoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
bias: bool = False,

@@ -115,8 +117,10 @@ class CutlassFusedMoE(MoE):
self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream()
self.aux_stream = aux_stream_dict[AuxStreamType.MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream()
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeChunkingOverlap]
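The constructor change above keeps the old behavior as a fallback: if no stream dict is supplied, a private stream is created so chunk overlap still works. A sketch of that selection, assuming a CUDA device:

import torch
from enum import Enum

AuxStreamType = Enum('AuxStreamType', ['MoeChunkingOverlap'])

def chunking_stream(aux_stream_dict=None):
    # Prefer the shared stream from the model-level dict; otherwise
    # fall back to a private stream, as before this change.
    if aux_stream_dict is not None:
        return aux_stream_dict[AuxStreamType.MoeChunkingOverlap]
    return torch.cuda.Stream()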
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F

@@ -11,7 +11,7 @@ from tensorrt_llm._utils import nvtx_range

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from ...utils import AuxStreamType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)

@@ -299,7 +299,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.

@@ -320,7 +320,8 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
apply_router_weight_on_input: bool = False,

@@ -335,7 +336,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
@@ -1207,7 +1207,6 @@ class TritonFusedMoE(MoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
bias: bool = False,
@@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
@@ -6,12 +6,13 @@ import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

from ...distributed import allgather, reducescatter
from ...distributed import AllReduce, allgather, reducescatter
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig
from ...utils import EventType, Fp4QuantizedTensor
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_load_balancer import get_moe_load_balancer

@@ -43,7 +44,7 @@ class WideEPMoE(MoE):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.

@@ -64,7 +65,8 @@ class WideEPMoE(MoE):
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
apply_router_weight_on_input: bool = False,

@@ -106,7 +108,11 @@ class WideEPMoE(MoE):
top_k = self.routing_method.experts_per_token
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.layer_load_balancer = moe_load_balancer.add_layer(
self.num_experts, top_k, self.expert_size_per_partition)
self.num_experts,
top_k,
self.expert_size_per_partition,
aux_stream=None if aux_stream_dict is None else aux_stream_dict[AuxStreamType.MoeBalancer])
self.repeat_count = self.layer_load_balancer.get_repeat_count()
loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
self.layer_idx)

@@ -130,6 +136,12 @@ class WideEPMoE(MoE):
assert num_experts % self.ep_size == 0
self.expert_size_per_partition = num_experts // self.ep_size
self.num_slots = num_experts
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing():
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=AllReduceStrategy.NCCL)
else:
self.allreduce = None

self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition

@@ -144,8 +156,10 @@ class WideEPMoE(MoE):
self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream()
self.aux_stream = aux_stream_dict[AuxStreamType.MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream()
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeChunkingOverlap]

@@ -371,9 +385,8 @@ class WideEPMoE(MoE):

is_first_call, is_last_call = repeating_info

if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing() and is_first_call:
self.layer_load_balancer.wait_for_gpu_stage()
if self.layer_load_balancer and is_first_call:
self.layer_load_balancer.start_wait_gpu_stage()

use_deepseek_fp8_block_scale = False
use_w4_group_scaling = False

@@ -401,40 +414,32 @@ class WideEPMoE(MoE):
else:
token_final_scales = None

if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing() and is_first_call:
self.layer_load_balancer.maybe_cudagraph_done_wait()

use_allgather = not use_all_to_all

loadbalancer_local_statistic_info = None
gathered_loadbalancer_local_statistic_info = None
token_selected_experts_for_statistic = None
if self.layer_load_balancer is None:
token_selected_slots = token_selected_experts
else:
if not self.layer_load_balancer.is_static_routing() and use_all_to_all:
self.layer_load_balancer.local_statistic(
if self.layer_load_balancer:
if is_first_call:
self.layer_load_balancer.done_wait_gpu_stage()
if use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL:
self.layer_load_balancer.update_local_statistic(
token_selected_experts,
is_first_stage=is_first_call,
is_last_stage=is_last_call)
else:
self.layer_load_balancer.update_statistic_with_local_ids(
token_selected_experts,
is_first_stage=is_first_call,
is_last_stage=is_last_call,
allreduce=self.allreduce)
token_selected_slots = self.layer_load_balancer.route(
token_selected_experts, self.use_dp)
if not self.layer_load_balancer.is_static_routing():
# split into two parts to get possible overlap with load balancer routing
if use_all_to_all:
if is_last_call:
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor()
else:
token_selected_experts_for_statistic = token_selected_experts
else:
token_selected_slots = token_selected_experts

# If load balancer is disabled, the statistics are collected from expert IDs.
# If load balancer is enabled, the statistics are collected from expert slot IDs.
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

use_allgather = not use_all_to_all

# If alltoall is disabled, we also need to disable use_postquant_alltoall
use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
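The rework above folds both statistic paths behind `done_wait_gpu_stage`, and which path runs depends on the communication backend. A control-flow sketch, using the method names introduced in this diff (the balancer object and flags are assumed):

def collect_statistics(balancer, token_selected_experts, use_all_to_all,
                       is_mnnvl, is_first_call, is_last_call, allreduce):
    if use_all_to_all and is_mnnvl:
        # MNNVL path: keep statistics local; the per-rank tensor rides
        # along with the alltoall dispatch and is summed afterwards.
        balancer.update_local_statistic(token_selected_experts,
                                        is_first_stage=is_first_call,
                                        is_last_stage=is_last_call)
    else:
        # Other paths: reduce local counts directly with an NCCL
        # all-reduce, so no separate allgather is needed.
        balancer.update_statistic_with_local_ids(token_selected_experts,
                                                 is_first_stage=is_first_call,
                                                 is_last_stage=is_last_call,
                                                 allreduce=allreduce)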
@@ -456,6 +461,11 @@ class WideEPMoE(MoE):
self.dummy_allreduce()
token_count = x.shape[0]
alltoall_info = None
if is_last_call:
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor()
else:
loadbalancer_local_statistic_info = None
x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens,
x,

@@ -463,6 +473,11 @@ class WideEPMoE(MoE):
token_final_scales,
use_postquant_alltoall,
loadbalancer_local_statistic_info)
if gathered_loadbalancer_local_statistic_info is not None:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
self.layer_load_balancer.update_statistic_with_gathered_statistic(
gathered_loadbalancer_local_statistic_info)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if not use_postquant_alltoall:
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \

@@ -470,10 +485,6 @@ class WideEPMoE(MoE):
self.expert_size_per_partition * self.mapping.moe_ep_rank)
padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
x, None, recv_topk_idx, token_final_scales)
if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing():
gathered_loadbalancer_local_statistic_info = allgather(
loadbalancer_local_statistic_info, self.mapping, dim=0)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
if not use_postquant_alltoall:
deep_ep_topk_idx = token_selected_slots

@@ -503,10 +514,6 @@ class WideEPMoE(MoE):
x.shape[0], 1)
token_final_scales = torch.ones_like(
token_selected_slots, dtype=token_final_scales.dtype)
if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing():
gathered_loadbalancer_local_statistic_info = allgather(
loadbalancer_local_statistic_info, self.mapping, dim=0)

x_sf = None
x_row = x.shape[0]

@@ -550,33 +557,18 @@ class WideEPMoE(MoE):
# using allgather case.
if self.enable_dummy_allreduce:
self.dummy_allreduce()
x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
x, x_sf, token_selected_slots, token_final_scales = allgather(
[
x,
x_sf,
token_selected_slots,
token_final_scales,
token_selected_experts_for_statistic,
],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
x_row = x.shape[0]

if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing():
if use_all_to_all:
if is_last_call:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
self.layer_load_balancer.update_statistic(
gathered_loadbalancer_local_statistic_info)
else:
self.layer_load_balancer.statistic(
gathered_token_selected_experts_for_statistic,
is_first_stage=is_first_call,
is_last_stage=is_last_call)

ep_size = self.ep_size
ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight

@@ -670,9 +662,8 @@ class WideEPMoE(MoE):
tuner_top_k=tuner_top_k,
)

if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing() and is_last_call:
self.layer_load_balancer.set_cpu_stage()
if self.layer_load_balancer and is_last_call:
self.layer_load_balancer.start_set_cpu_stage()

# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.

@@ -702,9 +693,8 @@ class WideEPMoE(MoE):
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)

if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing() and is_last_call:
self.layer_load_balancer.maybe_cudagraph_done_set_cpu_stage()
if self.layer_load_balancer and is_last_call:
self.layer_load_balancer.done_set_cpu_stage()

return final_hidden_states
@@ -24,7 +24,6 @@ class MoE(nn.Module):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
@@ -13,6 +13,9 @@ from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ...distributed import AllReduce
from ...utils import EventType


def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
"""

@@ -271,7 +274,8 @@ class SingleLayerMoeLoadBalancer:
shared_mpi_comm: MPI.Comm,
expert_count: int,
updates_enabled: bool = True,
repeated_count=1):
repeated_count=1,
aux_stream: Optional[torch.cuda.Stream] = None):
"""
Initialize a SingleLayerMoeLoadBalancer instance.

@@ -287,6 +291,7 @@ class SingleLayerMoeLoadBalancer:
)
self.expert_count = expert_count
self.updates_enabled = updates_enabled
self.repeated_count = repeated_count
layer_id = self.single_layer_load_balancer_impl.get_layer_id()
self.host_tensor_sharer = HostMoeTensorSharer(
layer_id, expert_count,

@@ -303,15 +308,34 @@ class SingleLayerMoeLoadBalancer:
self.expert_count)
self.load_expert_ids = list(range(load_expert_start, load_expert_end))

if self.updates_enabled:
self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream()
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeBalancer]
}
else:
self.aux_stream = None
self.event_dict = None

self.statistic_flag_tensor = None
self.local_statistic_tensor = None

self.cudagraph_stream = None
self.cudagraph_event = None
self.repeated_count = repeated_count

self.statistic_stream = None
self.statistic_event = None
self.func_called_count = {
name: 0
for name in [
"start_wait_gpu_stage",
"done_wait_gpu_stage",
"start_set_cpu_stage",
"done_set_cpu_stage",
"update_local_statistic",
"get_local_statistic_tensor",
"update_statistic_with_gathered_statistic",
"update_statistic_with_local_ids",
"update_statistic_with_global_ids",
"route",
]
}

def get_layer_idx(self):
return self.single_layer_load_balancer_impl.get_layer_id()
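The `func_called_count` map added above turns the stage methods into an explicit per-iteration protocol: each start/done pair must run exactly once and in order, and the final stage resets the counters. A toy version of the same check (simplified; the real class tracks ten methods):

calls = {"start_wait_gpu_stage": 0, "done_wait_gpu_stage": 0,
         "start_set_cpu_stage": 0, "done_set_cpu_stage": 0}

def done_wait_gpu_stage():
    assert calls["start_wait_gpu_stage"] == 1, "start must run first"
    assert calls["done_wait_gpu_stage"] == 0, "done may run only once"
    calls["done_wait_gpu_stage"] += 1

def done_set_cpu_stage():
    assert calls["start_set_cpu_stage"] == 1
    for name in calls:  # end of the iteration: reset the protocol
        calls[name] = 0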
@@ -441,139 +465,94 @@ class SingleLayerMoeLoadBalancer:
self.host_tensor_sharer.finalize_host_tensor_sharing(
self._add_host_weight_from_tensor)

def wait_for_gpu_stage(self) -> Optional[torch.Tensor]:
def start_wait_gpu_stage(self):
"""
Wait for the GPU stage to complete.

Returns:
A tensor indicating whether the stage is enabled
Start to wait for the GPU stage to complete.
"""
assert self.func_called_count["start_wait_gpu_stage"] == 0
self.func_called_count["start_wait_gpu_stage"] += 1
if self.updates_enabled:
assert self.statistic_flag_tensor is None, "Already has statistic_flag_tensor, should not wait."
if is_graph_capturing():
self.cudagraph_event = torch.cuda.Event()
self.cudagraph_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.cudagraph_stream):
current_stream_event.wait()
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(self.single_layer_load_balancer_ptr)
self.cudagraph_event.record(self.cudagraph_stream)
self.event_dict[EventType.MoeBalancer].record()
else:
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(self.single_layer_load_balancer_ptr)
return self.statistic_flag_tensor
else:
return

def maybe_cudagraph_done_wait(self):
def done_wait_gpu_stage(self):
"""
Done waiting for the GPU stage to complete.
"""
assert self.func_called_count["start_wait_gpu_stage"] == 1
assert self.func_called_count["done_wait_gpu_stage"] == 0
self.func_called_count["done_wait_gpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
self.cudagraph_event.wait()
self.event_dict[EventType.MoeBalancer].wait()

def set_cpu_stage(self):
def start_set_cpu_stage(self):
"""
Set the CPU stage.
Start to set the CPU stage.
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count["start_set_cpu_stage"] == 0
self.func_called_count["start_set_cpu_stage"] += 1
if self.updates_enabled:
assert self.statistic_flag_tensor is not None, "Doesn't have statistic_flag_tensor, should not set_cpu_stage."
self.statistic_flag_tensor = None
if is_graph_capturing():
assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage."
assert self.statistic_event is not None
assert self.statistic_stream is not None
# wait statistic update done
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.cudagraph_stream):
self.statistic_event.wait()
current_stream_event.wait()
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)
self.cudagraph_event.record(self.cudagraph_stream)
self.statistic_event = None
self.statistic_stream = None
self.event_dict[EventType.MoeBalancer].record()
else:
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)

def maybe_cudagraph_done_set_cpu_stage(self):
def done_set_cpu_stage(self):
"""
Done setting the CPU stage.
"""
assert self.func_called_count["start_set_cpu_stage"] == 1
for name in self.func_called_count:
self.func_called_count[name] = 0
self.statistic_flag_tensor = None
if self.updates_enabled:
if is_graph_capturing():
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
self.cudagraph_event.wait()
self.cudagraph_stream = None
self.cudagraph_event = None
self.event_dict[EventType.MoeBalancer].wait()

def statistic(self, gathered_raw_expert_ids: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
"""
Perform statistics on the expert IDs.
Update local statistics of the expert IDs.

Args:
gathered_raw_expert_ids: The gathered raw expert IDs from all ranks
local_raw_expert_ids: The local raw expert IDs
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count["update_statistic_with_global_ids"] == 0
self.func_called_count["update_local_statistic"] += 1
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
if is_graph_capturing():
if is_first_stage:
self.statistic_event = torch.cuda.Event()
self.statistic_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.statistic_stream):
current_stream_event.wait()
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
self.statistic_event.record()
else:
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)

def local_statistic(self, local_raw_expert_ids: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
"""
Perform local statistics on the expert IDs.

Args:
local_raw_expert_ids: The gathered raw expert IDs from all ranks
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
if is_first_stage:
assert self.local_statistic_tensor is None
if self.local_statistic_tensor is None:
self.local_statistic_tensor = torch.empty(
(self.expert_count, ),
dtype=torch.int32,
device=torch.device('cuda'))
if is_graph_capturing():
if is_first_stage:
self.statistic_event = torch.cuda.Event()
self.statistic_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.statistic_stream):
current_stream_event.wait()
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_hierarchical_statistic_local_device(
local_raw_expert_ids, self.local_statistic_tensor,
self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
self.statistic_event.record(self.statistic_stream)
else:
torch.ops.trtllm.moe_hierarchical_statistic_local_device(
local_raw_expert_ids, self.local_statistic_tensor,
@@ -581,48 +560,119 @@ class SingleLayerMoeLoadBalancer:
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)

def get_local_statistic_tensor(self):
def get_local_statistic_tensor(self) -> Optional[torch.Tensor]:
"""
Get the local statistic tensor. Should perform allreduce on it and then call update_statistic
Get the local statistic tensor.
Returns:
The local statistic tensor if using statistic else None
"""
assert self.func_called_count["update_local_statistic"] > 0
self.func_called_count["get_local_statistic_tensor"] += 1
if self.updates_enabled:
assert self.local_statistic_tensor is not None
if is_graph_capturing():
assert self.statistic_event is not None
assert self.statistic_stream is not None
self.statistic_event.wait()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeBalancer].record()
self.event_dict[EventType.MoeBalancer].wait()
return self.local_statistic_tensor
return None

def update_statistic(self, gathered_local_statistic_tensor: torch.Tensor):
def update_statistic_with_gathered_statistic(
self, gathered_local_statistic_tensor: torch.Tensor):
"""
Perform update with global statistics.
Update statistics of the expert IDs, using gathered local statistic tensors.

Args:
gathered_local_statistic_tensor: gathered local statistics info, should have shape (world_size, self.expert_count)
"""
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
assert self.func_called_count["get_local_statistic_tensor"] > 0
assert self.func_called_count["update_statistic_with_local_ids"] == 0
assert self.func_called_count["update_statistic_with_global_ids"] == 0
self.func_called_count["update_statistic_with_gathered_statistic"] += 1

def _update_statistic():
global_statistic_info = torch.sum(
gathered_local_statistic_tensor, dim=0, dtype=torch.int32)
def _update_statistic():
global_statistic_info = torch.sum(gathered_local_statistic_tensor,
dim=0,
dtype=torch.int32)
torch.ops.trtllm.moe_hierarchical_statistic_update(
global_statistic_info, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr)

if self.updates_enabled:
if is_graph_capturing():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
_update_statistic()
else:
_update_statistic()

def update_statistic_with_local_ids(self,
local_raw_expert_ids: torch.Tensor,
is_first_stage: bool,
is_last_stage: bool,
allreduce: Optional[AllReduce] = None):
"""
Update statistics of the expert IDs, using local raw expert IDs.

Args:
local_raw_expert_ids: The local raw expert IDs
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
allreduce: The allreduce object
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count["update_statistic_with_gathered_statistic"] == 0
assert self.func_called_count["update_statistic_with_global_ids"] == 0
self.func_called_count["update_statistic_with_local_ids"] += 1

def _update_statistic():
if is_last_stage:
global_statistic_info = allreduce(self.local_statistic_tensor)
torch.ops.trtllm.moe_hierarchical_statistic_update(
global_statistic_info, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr)

if self.updates_enabled:
self.update_local_statistic(local_raw_expert_ids, is_first_stage,
is_last_stage)
if is_graph_capturing():
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.statistic_stream):
current_stream_event.wait()
with torch.cuda.stream(self.aux_stream):
_update_statistic()
self.statistic_event.record(self.statistic_stream)
else:
_update_statistic()
self.local_statistic_tensor = None

def update_statistic_with_global_ids(self,
gathered_raw_expert_ids: torch.Tensor,
is_first_stage: bool,
is_last_stage: bool):
"""
Update statistics of the expert IDs, using gathered raw expert IDs from all ranks.

Args:
gathered_raw_expert_ids: The gathered raw expert IDs from all ranks
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
assert self.func_called_count["update_statistic_with_gathered_statistic"] == 0
assert self.func_called_count["update_statistic_with_local_ids"] == 0
self.func_called_count["update_statistic_with_global_ids"] += 1
if self.updates_enabled:
if is_graph_capturing():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
else:
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)

def route(self,
token_selected_experts: torch.Tensor,
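For reference, the reduction inside `update_statistic_with_gathered_statistic` collapses the gathered per-rank counts into global totals. A runnable CPU example with made-up sizes:

import torch

world_size, expert_count = 4, 8  # made-up sizes for illustration
gathered = torch.randint(0, 100, (world_size, expert_count), dtype=torch.int32)

# (world_size, expert_count) -> (expert_count,): per-expert token counts
# summed across ranks, kept in int32 as the update kernel expects.
global_statistic_info = torch.sum(gathered, dim=0, dtype=torch.int32)
assert global_statistic_info.shape == (expert_count,)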
@@ -637,6 +687,8 @@ class SingleLayerMoeLoadBalancer:
Returns:
A tensor of routed slot IDs
"""
assert self.func_called_count["done_wait_gpu_stage"] == 1
self.func_called_count["route"] += 1
return torch.ops.trtllm.moe_load_balance_routing(
token_selected_experts, offset_by_ep_rank,
self.single_layer_load_balancer_ptr)

@@ -731,8 +783,13 @@ class MoeLoadBalancer:
assert repeated_count > 0, "repeat count must be greater than 0"
self.next_layer_repeated_count = repeated_count

def add_layer(self, expert_count: int, top_k: int,
slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
def add_layer(
self,
expert_count: int,
top_k: int,
slot_count_per_rank: int,
aux_stream: Optional[torch.cuda.Stream] = None
) -> SingleLayerMoeLoadBalancer:
"""
Add a new layer to the load balancer.

@@ -740,6 +797,7 @@ class MoeLoadBalancer:
expert_count: The number of experts in the layer
top_k: The number of experts each token selects
slot_count_per_rank: The number of slots per rank
aux_stream: The auxiliary stream for overlapping

Returns:
A SingleLayerMoeLoadBalancer instance for the new layer

@@ -756,7 +814,8 @@ class MoeLoadBalancer:
self.shared_mpi_comm,
expert_count,
updates_enabled=updates_enabled,
repeated_count=repeat_count)
repeated_count=repeat_count,
aux_stream=aux_stream)
single_layer_load_balancer.set_shared_memory_base_name(
self.shared_memory_base_name)
self.single_layer_load_balancers.append(single_layer_load_balancer)

@@ -988,8 +1047,11 @@ def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int):


def moe_load_balancer_add_single_layer(
expert_count: int, top_k: int,
slot_count_per_rank: int) -> Optional[SingleLayerMoeLoadBalancer]:
expert_count: int,
top_k: int,
slot_count_per_rank: int,
aux_stream: Optional[torch.cuda.Stream] = None
) -> Optional[SingleLayerMoeLoadBalancer]:
"""
Add a new layer to the current active MoeLoadBalancer.

@@ -997,11 +1059,13 @@ def moe_load_balancer_add_single_layer(
expert_count: The number of experts in the layer
top_k: The number of experts each token selects
slot_count_per_rank: The number of slots per rank
aux_stream: The auxiliary stream for overlapping

Returns:
A SingleLayerMoeLoadBalancer instance for the new layer, or None if not in a MoeLoadBalancer context
"""
load_balancer = get_moe_load_balancer()
if load_balancer is not None:
return load_balancer.add_layer(expert_count, top_k, slot_count_per_rank)
return load_balancer.add_layer(expert_count, top_k, slot_count_per_rank,
aux_stream)
return None
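A hedged usage sketch of the extended `add_layer` signature: callers that want balancer work off the main stream now pass a dedicated stream. The balancer object is assumed to come from an active MoeLoadBalancer context, and the sizes are illustrative:

import torch

def add_balanced_layer(load_balancer):
    if load_balancer is None:  # not inside a balancer context
        return None
    aux = torch.cuda.Stream()  # dedicated MoeBalancer stream
    return load_balancer.add_layer(expert_count=256, top_k=8,
                                   slot_count_per_rank=8,
                                   aux_stream=aux)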
@@ -12,7 +12,12 @@ from tensorrt_llm.quantization.utils import fp4_utils

is_torch_compiling_flag = False

aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
aux_stream_name_list = [
'Attention',
'MoeShared',
'MoeChunkingOverlap',
'MoeBalancer',
]
AuxStreamType = Enum(
'AuxStreamType',
aux_stream_name_list,
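The functional `Enum` API used here creates one member per name, so `AuxStreamType.MoeBalancer` becomes a valid key for the stream dictionaries above. A quick self-contained check:

from enum import Enum

aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap', 'MoeBalancer']
AuxStreamType = Enum('AuxStreamType', aux_stream_name_list)

assert AuxStreamType.MoeBalancer.name == 'MoeBalancer'
assert len(AuxStreamType) == 4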
@@ -222,22 +222,18 @@ class TestMoeLoadBalancer(unittest.TestCase):

# wait_for_gpu_stage
mock_wait.return_value = torch.tensor([1])
layer.wait_for_gpu_stage()
layer.start_wait_gpu_stage()
layer.done_wait_gpu_stage()
result = layer.statistic_flag_tensor
mock_wait.assert_called_once_with(
mock_single_layer_impl.get_pointer())
self.assertEqual(result, mock_wait.return_value)

# set_cpu_stage
layer.set_cpu_stage()
mock_set_cpu.assert_called_once_with(
mock_single_layer_impl.get_pointer())

# statistic
mock_expert_ids = torch.tensor([[0, 1], [2, 3]])
mock_enabled = torch.tensor([1])
layer.statistic_flag_tensor = mock_enabled
layer.statistic(mock_expert_ids, True, False)
layer.update_statistic_with_global_ids(mock_expert_ids, True, False)
mock_statistic.assert_called_once_with(
mock_expert_ids, mock_enabled,
mock_single_layer_impl.get_pointer(), True, False)

@@ -248,6 +244,12 @@ class TestMoeLoadBalancer(unittest.TestCase):
result = layer.route(mock_selected_experts)
assert torch.equal(result, mock_route.return_value)

# set_cpu_stage
layer.start_set_cpu_stage()
layer.done_set_cpu_stage()
mock_set_cpu.assert_called_once_with(
mock_single_layer_impl.get_pointer())

@patch('tensorrt_llm.bindings.internal.runtime.MoeLoadBalancer')
def test_moe_load_balancer_lifecycle_methods(self, mock_load_balancer_impl):
"""Test lifecycle methods of MoeLoadBalancer."""

@@ -323,13 +325,16 @@ class TestMoeLoadBalancer(unittest.TestCase):
try:
with MoeLoadBalancerIterContext(balancer):
# Wait for GPU stage and get enabled flag
layer.wait_for_gpu_stage()
layer.start_wait_gpu_stage()
layer.done_wait_gpu_stage()

# Run statistic - just test it runs without error
layer.statistic(gathered_raw_expert_ids, True, True)
layer.update_statistic_with_global_ids(gathered_raw_expert_ids, True, True)

# Set CPU stage to signal completion
layer.set_cpu_stage()
layer.start_set_cpu_stage()
layer.done_set_cpu_stage()

# Test passed if we got here without exceptions
self.assertTrue(True, "Statistic kernel ran successfully")

@@ -384,13 +389,15 @@ class TestMoeLoadBalancer(unittest.TestCase):
try:
with MoeLoadBalancerIterContext(balancer):
# Wait for GPU stage
layer.wait_for_gpu_stage()
layer.start_wait_gpu_stage()
layer.done_wait_gpu_stage()

# Run routing
routed_slots = layer.route(token_selected_experts)

# Set CPU stage
layer.set_cpu_stage()
layer.start_set_cpu_stage()
layer.done_set_cpu_stage()

# Verify results - with our initial assignment, expert i should map to slot i
expected_slots = torch.tensor(