feat: large-scale EP (part 4: Static EP load balancer integration) (#4615)

* MoeLoadBalancerConfig

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

* MoeLoadBalancer integration

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

* config file

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

* test

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

* test

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

* fix

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

---------

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu, 2025-05-26 18:25:11 +08:00, committed by GitHub
parent 44eb053b95
commit 88190faa34
GPG Key ID: B5690EEEBB952194
7 changed files with 212 additions and 55 deletions


@@ -15,6 +15,51 @@ from tensorrt_llm.quantization.mode import QuantAlgo
TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
@dataclass
class MoeLoadBalancerConfig:
num_slots: Optional[int] = None
initial_global_assignments: Optional[Dict[int, List[int]]] = None
layer_updates_per_iter: int = 0
num_experts: Optional[int] = field(default=None, init=False)
ep_rank: Optional[int] = field(default=None, init=False)
ep_size: Optional[int] = field(default=None, init=False)
def setup(self, num_experts: int, ep_rank: int, ep_size: int) -> None:
self.num_experts = num_experts
self.ep_rank = ep_rank
self.ep_size = ep_size
if self.num_slots is None:
self.num_slots = self.num_experts
assert self.num_slots >= self.num_experts
assert self.num_slots % self.ep_size == 0
@property
def num_local_slots(self) -> int:
return self.num_slots // self.ep_size
@property
def slot_start(self) -> int:
return self.ep_rank * self.num_local_slots
@property
def slot_end(self) -> int:
return self.slot_start + self.num_local_slots
def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
if self.initial_global_assignments is None:
return [(ep_rank * self.num_experts // self.ep_size + i) %
self.num_experts for ep_rank in range(self.ep_size)
for i in range(self.num_local_slots)]
else:
assert layer_idx in self.initial_global_assignments
assert len(
self.initial_global_assignments[layer_idx]) == self.num_slots
assert set(self.initial_global_assignments[layer_idx]) == set(
range(self.num_experts))
return self.initial_global_assignments[layer_idx]
@dataclass(kw_only=True)
class ModelConfig(Generic[TConfig]):
pretrained_config: Optional[TConfig] = None
@@ -28,6 +73,7 @@ class ModelConfig(Generic[TConfig]):
is_generation: bool = True
max_num_tokens: int = 8192
moe_max_num_tokens: Optional[int] = None
moe_load_balancer: Optional[MoeLoadBalancerConfig] = None
attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM
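
An illustrative, standalone sketch of the default layout produced by get_layer_initial_global_assignments when no explicit assignments are given (the numbers below are examples, not values from this commit):

# Reproduces the default round-robin slot layout from
# MoeLoadBalancerConfig.get_layer_initial_global_assignments.
def default_assignments(num_experts: int, num_slots: int, ep_size: int) -> list:
    num_local_slots = num_slots // ep_size
    return [(ep_rank * num_experts // ep_size + i) % num_experts
            for ep_rank in range(ep_size)
            for i in range(num_local_slots)]

# num_slots == num_experts: each EP rank hosts a contiguous block of experts.
print(default_assignments(num_experts=8, num_slots=8, ep_size=4))
# -> [0, 1, 2, 3, 4, 5, 6, 7]

# num_slots > num_experts: the redundant slots wrap around modulo num_experts,
# so some experts get a second (replicated) slot on a neighboring rank.
print(default_assignments(num_experts=8, num_slots=12, ep_size=4))
# -> [0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 0]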


@@ -54,6 +54,7 @@ from ..modules.embedding import Embedding
from ..modules.fused_moe import DeepSeekV3MoeRoutingMethod, FusedMoE
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear
from ..modules.moe_load_balancer import MoeLoadBalancer
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -339,7 +340,9 @@ class Deepseekv3MoE(nn.Module):
shared_expert_intermediate_size: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
dtype: Optional[torch.dtype] = None,
model_config: ModelConfig = ModelConfig()):
model_config: ModelConfig = ModelConfig(),
moe_load_balancer: Optional[MoeLoadBalancer] = None,
layer_idx: Optional[int] = None):
from ..distributed import AllReduce
super().__init__()
@@ -371,7 +374,9 @@
False, # In both low-latency and attention-DP modes, FusedMoE skips the in-op allreduce.
model_config=model_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
enable_alltoall=self.enable_alltoall)
enable_alltoall=self.enable_alltoall,
moe_load_balancer=moe_load_balancer,
layer_idx=layer_idx)
self.mapping = model_config.mapping
@@ -531,9 +536,11 @@ class Deepseekv3MoE(nn.Module):
class DeepseekV3DecoderLayer(DecoderLayer):
def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
def __init__(self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
moe_load_balancer: Optional[MoeLoadBalancer] = None):
super().__init__()
self.model_config = model_config
config = model_config.pretrained_config
@@ -580,7 +587,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.num_shared_experts,
dtype=config.torch_dtype,
model_config=model_config,
aux_stream_dict=aux_stream_dict)
aux_stream_dict=aux_stream_dict,
moe_load_balancer=moe_load_balancer,
layer_idx=layer_idx)
else:
block_size = 1
if model_config.quant_config and model_config.quant_config.group_size is not None:
@@ -952,11 +961,27 @@ class DeepseekV3Model(DecoderModel):
dtype=config.torch_dtype,
)
self.moe_load_balancer = None
if model_config.moe_load_balancer is not None:
num_experts = config.n_routed_experts
ep_rank = model_config.mapping.moe_ep_rank
ep_size = model_config.mapping.moe_ep_size
model_config.moe_load_balancer.setup(num_experts=num_experts,
ep_rank=ep_rank,
ep_size=ep_size)
self.moe_load_balancer = MoeLoadBalancer(
ep_rank=ep_rank,
ep_size=ep_size,
layer_updates_per_iter=model_config.moe_load_balancer.
layer_updates_per_iter)
self.layers = nn.ModuleList([
DeepseekV3DecoderLayer(model_config, layer_idx,
self.aux_stream_dict)
self.aux_stream_dict, self.moe_load_balancer)
for layer_idx in range(config.num_hidden_layers)
])
if self.moe_load_balancer is not None:
self.moe_load_balancer.finalize_model()
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
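
Putting the wiring above together, the call order for a static load balancer looks roughly like this (a condensed sketch with illustrative values such as top_k=8; constructing MoeLoadBalancer outside a full model may additionally require an initialized distributed/GPU environment):

from tensorrt_llm._torch.model_config import MoeLoadBalancerConfig
from tensorrt_llm._torch.modules.moe_load_balancer import MoeLoadBalancer

cfg = MoeLoadBalancerConfig(num_slots=80, layer_updates_per_iter=0)  # static EPLB
cfg.setup(num_experts=72, ep_rank=0, ep_size=4)

balancer = MoeLoadBalancer(ep_rank=cfg.ep_rank, ep_size=cfg.ep_size,
                           layer_updates_per_iter=cfg.layer_updates_per_iter)

# One balancer layer per MoE decoder layer; FusedMoE does this in its constructor.
balancer_layer = balancer.add_layer(expert_count=72, top_k=8,
                                    slot_count_per_rank=cfg.num_local_slots)
balancer_layer.set_initial_weight_assignments(
    cfg.get_layer_initial_global_assignments(layer_idx=1))

# Called once after all decoder layers have been built.
balancer.finalize_model()

# At runtime, FusedMoE maps expert ids to slot ids through the layer:
#   token_selected_slots = balancer_layer.route(token_selected_experts)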


@@ -8,7 +8,7 @@ import torch
from torch import nn
from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm._utils import get_sm_version, logger
from tensorrt_llm.quantization.utils.fp4_utils import (
get_reorder_rows_for_gated_act_gemm_row_indices,
get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices,
@@ -16,10 +16,11 @@ from tensorrt_llm.quantization.utils.fp4_utils import (
from ...quantization.utils.fp4_utils import float4_sf_dtype
from ..distributed import allgather, reducescatter
from ..model_config import ModelConfig
from ..model_config import ModelConfig, MoeLoadBalancerConfig
from ..utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
reswizzle_sf, swizzle_sf, unswizzle_sf)
from .linear import TensorParallelMode, load_weight_shard
from .moe_load_balancer import MoeLoadBalancer
# The declarations align with moe_kernels.h
# pack inputs into int64, e.g. 4 x bf16 input values
@@ -367,6 +368,8 @@ class FusedMoE(nn.Module):
VANILLA,
apply_router_weight_on_input: bool = False,
enable_alltoall: bool = False,
moe_load_balancer: Optional[MoeLoadBalancer] = None,
layer_idx: Optional[int] = None,
):
from ..distributed import AllReduce
@@ -404,25 +407,46 @@ class FusedMoE(nn.Module):
self.intermediate_size_per_partition = intermediate_size // self.tp_size
# self.expert_slots_per_partition will be replaced with real slots_per_partition to enable redundant expert slots
self.expert_slots_per_partition = num_experts // self.ep_size
assert self.expert_slots_per_partition * self.ep_size >= num_experts, "total slots should be at least num_experts"
if self.smart_router:
assert self.expert_slots_per_partition == num_experts // self.ep_size,\
"Smart router should not have redundant slots"
self.num_slots = self.expert_slots_per_partition * self.ep_size
# Here the meaning of expert_size_per_partition is the number of expert slots that each rank has.
self.expert_size_per_partition = self.expert_slots_per_partition
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
moe_load_balancer_config = model_config.moe_load_balancer
if moe_load_balancer_config is None:
assert moe_load_balancer is None
# A dummy MoeLoadBalancerConfig to generate default initial_global_assignments and initial_local_expert_ids
moe_load_balancer_config = MoeLoadBalancerConfig()
moe_load_balancer_config.setup(num_experts=num_experts,
ep_rank=self.ep_rank,
ep_size=self.ep_size)
else:
assert moe_load_balancer is not None
self.initial_global_assignments = [
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
self.num_experts for ep_rank in range(self.ep_size)
for local_slot_id in range(self.expert_slots_per_partition)
]
self.num_slots = moe_load_balancer_config.num_slots
if self.smart_router:
assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
layer_idx)
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.slot_start = moe_load_balancer_config.slot_start
self.slot_end = moe_load_balancer_config.slot_end
self.initial_local_expert_ids = self.initial_global_assignments[
self.slot_start:self.slot_end]
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
self.balancer_layer = None
if moe_load_balancer is not None:
self.balancer_layer = moe_load_balancer.add_layer(
expert_count=num_experts,
top_k=routing_method.experts_per_token,
slot_count_per_rank=self.expert_size_per_partition,
)
self.balancer_layer.set_initial_weight_assignments(
self.initial_global_assignments)
logger.info(
f"MoE load balancer enabled. num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}"
)
logger.info(
f"initial_global_assignments (layer {layer_idx}) = {self.initial_global_assignments}"
)
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE is multiplied by DP size when attention DP is enabled
@@ -854,13 +878,18 @@ class FusedMoE(nn.Module):
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
if self.balancer_layer is None:
token_selected_slots = token_selected_experts
else:
token_selected_slots = self.balancer_layer.route(
token_selected_experts)
assert token_selected_experts.shape[
assert token_selected_slots.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_selected_slots.shape == token_final_scales.shape
assert token_selected_slots.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32
assert token_selected_slots.dtype == torch.int32
if self.apply_router_weight_on_input:
x = x * token_final_scales.to(x.dtype)
@@ -872,10 +901,10 @@ class FusedMoE(nn.Module):
alltoall_info = None
if self.enable_alltoall:
x, token_selected_experts, token_final_scales, alltoall_info = \
x, token_selected_slots, token_final_scales, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
x,
token_selected_experts,
token_selected_slots,
token_final_scales)
x_sf = None
@@ -910,15 +939,15 @@ class FusedMoE(nn.Module):
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
) and not self.enable_alltoall:
if x_sf is None:
x, token_selected_experts, token_final_scales = allgather(
[x, token_selected_experts, token_final_scales],
x, token_selected_slots, token_final_scales = allgather(
[x, token_selected_slots, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
else:
# Fp4 gemm has extra scaling factor
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
x, x_sf, token_selected_slots, token_final_scales = allgather(
[x, x_sf, token_selected_slots, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
@@ -953,7 +982,7 @@ class FusedMoE(nn.Module):
final_hidden_states = torch.ops.trtllm.fused_moe(
x,
token_selected_experts,
token_selected_slots,
token_final_scales,
w3_w1_weight.view(weight_dtype),
w2_weight.view(weight_dtype),
@@ -1234,31 +1263,29 @@ class FusedMoE(nn.Module):
def alltoall_prepare_maybe_dispatch(self, all_rank_num_tokens: list,
x: torch.Tensor,
token_selected_experts: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor):
top_k = self.routing_method.experts_per_token
expert_count = self.num_experts
# gather router info
max_num_token = max(all_rank_num_tokens)
token_selected_experts = torch.nn.functional.pad(
token_selected_experts,
(0, 0, 0, max_num_token - token_selected_experts.shape[0]),
token_selected_slots = torch.nn.functional.pad(
token_selected_slots,
(0, 0, 0, max_num_token - token_selected_slots.shape[0]),
'constant', self.num_experts)
token_final_scales = torch.nn.functional.pad(
token_final_scales,
(0, 0, 0, max_num_token - token_final_scales.shape[0]))
gathered_token_selected_experts, gathered_token_final_scales = allgather(
[token_selected_experts, token_final_scales], self.mapping, dim=0)
gathered_token_selected_experts = torch.flatten(
gathered_token_selected_experts.contiguous(),
start_dim=0,
end_dim=-2)
gathered_token_selected_slots, gathered_token_final_scales = allgather(
[token_selected_slots, token_final_scales], self.mapping, dim=0)
gathered_token_selected_slots = torch.flatten(
gathered_token_selected_slots.contiguous(), start_dim=0, end_dim=-2)
gathered_token_final_scales = torch.flatten(
gathered_token_final_scales.contiguous(), start_dim=0, end_dim=-2)
gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id(
gathered_token_selected_experts, self.num_experts, self.ep_size)
alltoall_info, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids, None, gathered_token_selected_experts,
gathered_token_selected_slots, self.num_experts, self.ep_size)
alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids, None, gathered_token_selected_slots,
gathered_token_final_scales, max_num_token, expert_count, top_k,
self.ep_rank, self.ep_size)
@@ -1270,7 +1297,7 @@ class FusedMoE(nn.Module):
self.alltoall_workspace,
self.ep_rank, self.ep_size)
return x, token_selected_experts, token_final_scales, alltoall_info
return x, token_selected_slots, token_final_scales, alltoall_info
def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
x_row: int, x_col: int,
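
For intuition only, a pure-Python illustration of what "routing to slots" means in the static case (layer_updates_per_iter=0): the expert-to-slot mapping is fixed, and each selected expert id is translated to one of the slots hosting it. The real balancer_layer.route() is provided by MoeLoadBalancer and can spread tokens across replicated slots; picking the first replica here is just for the example.

from collections import defaultdict

def build_expert_to_slots(initial_global_assignments):
    expert_to_slots = defaultdict(list)
    for slot_id, expert_id in enumerate(initial_global_assignments):
        expert_to_slots[expert_id].append(slot_id)
    return expert_to_slots

# 8 experts spread over 12 slots on 4 EP ranks (3 slots per rank).
assignments = [0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 0]
expert_to_slots = build_expert_to_slots(assignments)

token_selected_experts = [[0, 4], [2, 7]]  # top-2 expert ids per token
token_selected_slots = [[expert_to_slots[e][0] for e in token]  # first replica
                        for token in token_selected_experts]
print(token_selected_slots)  # -> [[0, 5], [2, 10]]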


@@ -1,13 +1,18 @@
import json
import math
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
import yaml
from tensorrt_llm.bindings.executor import ExecutorConfig
from ...builder import BuildConfig
from ...logger import logger
from ...mapping import Mapping
from ..model_config import MoeLoadBalancerConfig
from ..speculative import SpecConfig
from .resource_manager import BaseResourceManager
@@ -49,6 +54,7 @@ class PyTorchConfig:
# If set, at most moe_max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time.
# If the number of tokens exceeds moe_max_num_tokens, the input tensors will be split into chunks and a for loop will be used.
moe_max_num_tokens: Optional[int] = None
moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS'
@@ -127,6 +133,22 @@ class PyTorchConfig:
self.cuda_graph_batch_sizes.append(
self.cuda_graph_max_batch_size)
if isinstance(self.moe_load_balancer, str):
assert os.path.exists(self.moe_load_balancer)
if self.moe_load_balancer.endswith(".json"):
with open(self.moe_load_balancer) as f:
self.moe_load_balancer = json.load(f)
elif self.moe_load_balancer.endswith((".yaml", ".yml")):
with open(self.moe_load_balancer) as f:
self.moe_load_balancer = yaml.safe_load(f)
else:
raise ValueError(
f"Unsupported moe load balancer config file: {self.moe_load_balancer}"
)
if isinstance(self.moe_load_balancer, dict):
self.moe_load_balancer = MoeLoadBalancerConfig(
**self.moe_load_balancer)
self._convert_load_format()
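
As an illustration of the new file-based configuration path, the snippet below writes a YAML file that this branch would accept and resolves it through PyTorchConfig; the file name and values are made up, and the keys simply mirror MoeLoadBalancerConfig's init fields:

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

eplb_yaml = """\
num_slots: 80
layer_updates_per_iter: 0
# initial_global_assignments is optional: a mapping from layer_idx to a list of
# num_slots expert ids; when omitted, the default round-robin layout is used.
"""
with open("moe_load_balancer.yaml", "w") as f:
    f.write(eplb_yaml)

# __post_init__ loads the file and converts it into a MoeLoadBalancerConfig.
pytorch_config = PyTorchConfig(moe_load_balancer="moe_load_balancer.yaml")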


@@ -40,7 +40,7 @@ from ..compilation.utils import set_enable_piecewise_cuda_graph_capture_flag
from ..distributed import MPIDist
from ..distributed.communicator import init_pp_comm
from ..metadata import KVCacheParams
from ..model_config import ModelConfig
from ..model_config import ModelConfig, MoeLoadBalancerConfig
from ..models import AutoModelForCausalLM
from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
timing)
@@ -323,6 +323,7 @@ class PyTorchModelEngine(ModelEngine):
load_format=pytorch_backend_config.load_format,
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
moe_load_balancer=pytorch_backend_config.moe_load_balancer,
lora_config=lora_config)
# In case that some tests use stub models and override `_load_model`.
if not hasattr(self.model, 'extra_attrs'):
@@ -880,7 +881,8 @@ class PyTorchModelEngine(ModelEngine):
checkpoint_dir: str,
load_format: LoadFormat,
max_num_tokens: int,
moe_max_num_tokens: int,
moe_max_num_tokens: Optional[int] = None,
moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
lora_config: Optional[LoraConfig] = None,
**kwargs):
config = ModelConfig.from_pretrained(checkpoint_dir,
@@ -889,6 +891,7 @@ class PyTorchModelEngine(ModelEngine):
config.spec_config = self.spec_config
config.max_num_tokens = max_num_tokens
config.moe_max_num_tokens = moe_max_num_tokens
config.moe_load_balancer = moe_load_balancer
config.lora_config = lora_config
validate_and_set_kv_cache_quant(


@@ -15,7 +15,8 @@
import pytest
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._torch.pyexecutor.config import (MoeLoadBalancerConfig,
PyTorchConfig)
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig, SamplingParams
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
@@ -511,10 +512,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
pytorch_backend_config=pytorch_config)
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
with llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["H100", "H200"])
@@ -582,6 +583,38 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["H100", "H200"])
def test_fp8_block_scales_4gpus_static_eplb(self):
# OOM on H100 with default free_gpu_memory_fraction=0.9
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
num_experts = 72
num_slots = 80
first_k_dense_replace = 1
num_hidden_layers = 30
initial_global_assignments = {}
for i in range(first_k_dense_replace, num_hidden_layers):
initial_global_assignments[i] = [(i + j) % num_experts
for j in range(num_slots)]
eplb_config = MoeLoadBalancerConfig(
num_slots=num_slots,
initial_global_assignments=initial_global_assignments,
layer_updates_per_iter=0)
pytorch_config = PyTorchConfig(use_cuda_graph=True,
moe_load_balancer=eplb_config)
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
pytorch_backend_config=pytorch_config,
enable_attention_dp=True)
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",


@@ -100,6 +100,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]