# TensorRT-LLM/examples/ep_load_balancer/generate_eplb_config.py

import argparse
import glob
import json
import os

import pandas as pd
import safetensors.torch
import torch
import yaml

from tensorrt_llm.bindings.internal.runtime import (MoeLoadBalanceMetaInfo,
                                                    MoePlacementCpuInfo,
                                                    do_placement,
                                                    do_replication)
from tensorrt_llm.logger import logger

logger.set_level("info")
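
# Illustrative invocation (flag values below are examples, not requirements;
# --expert_statistic_path must point at the directory produced by the offline
# statistics collection run):
#
#   python generate_eplb_config.py \
#       --expert_statistic_path ./expert_statistic \
#       --ep_size 8 --num_slots 320 \
#       --output_path moe_load_balancer.yaml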


def calculate_load_statistics(load_iters: torch.Tensor):
    """Compute mean, std, and imbalance ratio of per-rank loads.

    `load_iters` has shape (num_iters, ep_size): the token load of each rank
    at each iteration. Statistics are reported both on the total load summed
    over iterations and averaged per iteration.
    """
    # sum the loads over iterations and calculate the statistics
    load_total = load_iters.sum(dim=0)
    mean = load_total.mean().item()
    std = load_total.std().item()
    imbalance_ratio = (load_total.max().item() - load_total.min().item()) / mean
    stats = {
        "mean-total": mean,
        "std-total": std,
        "imbalance-ratio-total": imbalance_ratio
    }

    # calculate the statistics for each iteration and average over iterations
    mean = load_iters.mean(dim=-1).mean().item()
    std = load_iters.std(dim=-1).mean().item()
    imbalance_ratio = (load_iters.max(dim=-1).values -
                       load_iters.min(dim=-1).values) / load_iters.mean(dim=-1)
    imbalance_ratio = imbalance_ratio.mean().item()
    stats.update({
        "mean-iter": mean,
        "std-iter": std,
        "imbalance-ratio-iter": imbalance_ratio
    })
    return stats
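
# Worked example (illustrative numbers): for per-rank loads [80, 100, 120]
# the mean is 100 and the imbalance ratio is (120 - 80) / 100 = 0.4; a
# perfectly balanced placement would yield an imbalance ratio of 0.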


def save_eplb_config(config: dict, path: str):

    def represent_list_inline(dumper, data):
        # Dump lists in YAML flow style ([a, b, c]) so each per-layer
        # assignment stays on a single line.
        return dumper.represent_sequence('tag:yaml.org,2002:seq',
                                         data,
                                         flow_style=True)

    yaml.add_representer(list, represent_list_inline)
    with open(path, "w") as f:
        yaml.dump(config, f, width=float('inf'))
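
# The generated file has the following shape (values are illustrative; one
# assignment list per MoE layer, keyed by layer index):
#
#   num_slots: 320
#   initial_global_assignments:
#     3: [138, 81, 60, ..., 69]
#     4: [24, 243, 72, ..., 11]
#   layer_updates_per_iter: 0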


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--expert_statistic_path",
                        type=str,
                        default=os.environ.get("EXPERT_STATISTIC_PATH",
                                               "expert_statistic"))
    parser.add_argument("--iter_start", type=int, default=None)
    parser.add_argument("--iter_stop", type=int, default=None)
    parser.add_argument("--output_path",
                        type=str,
                        default="moe_load_balancer.yaml")
    parser.add_argument("--ep_size", type=int, default=8)
    parser.add_argument("--num_slots", type=int, default=320)
    parser.add_argument("--layer_updates_per_iter", type=int, default=0)
    args = parser.parse_args()

    with open(f"{args.expert_statistic_path}/meta_info.json", "r") as f:
        meta_info = json.load(f)
    num_experts = meta_info["num_experts"]
    num_experts_per_token = meta_info["num_experts_per_token"]

    # Sum the expert token counts over all ranks; keys have the form
    # "{iter_idx}_{layer_idx}" (see parse_key below).
    statistic = {}
    for statistic_file in glob.glob(
            f"{args.expert_statistic_path}/rank*.safetensors"):
        rank_statistic = safetensors.torch.load_file(statistic_file)
        for key, data in rank_statistic.items():
            if key not in statistic:
                statistic[key] = torch.zeros_like(data)
            statistic[key] += data

    def parse_key(key: str) -> tuple[int, int]:
        iter_idx, layer_idx = key.split("_")
        return int(iter_idx), int(layer_idx)

    statistic = {parse_key(key): data for key, data in statistic.items()}

    iters = sorted(set(iter_idx for iter_idx, _ in statistic))
    layers = sorted(set(layer_idx for _, layer_idx in statistic))
    num_iters = len(iters)
    num_layers = len(layers)
    assert iters[-1] + 1 - iters[0] == num_iters
    assert len(statistic) == num_iters * num_layers
    if args.iter_start is None:
        args.iter_start = iters[0]
    if args.iter_stop is None:
        args.iter_stop = iters[-1] + 1
    logger.info(f"Statistic iterations: {iters[0]} - {iters[-1] + 1}")
    logger.info(f"Used iterations: {args.iter_start} - {args.iter_stop}")
    logger.info(f"Statistic layers: {layers}")

    num_local_slots = args.num_slots // args.ep_size
    initial_global_assignments = {}
    load_stats = {}
    load_stats_rebalanced = {}
    for layer_idx in layers:
        expert_token_count_iters = [
            data for key, data in statistic.items() if
            args.iter_start <= key[0] < args.iter_stop and key[1] == layer_idx
        ]
        expert_token_count_iters = torch.stack(expert_token_count_iters, dim=0)
        assert expert_token_count_iters.size(
            0) == args.iter_stop - args.iter_start
        expert_load_factor = expert_token_count_iters.sum(dim=0).float()

        meta_info = MoeLoadBalanceMetaInfo(expert_count=num_experts,
                                           top_k=num_experts_per_token,
                                           ep_rank=0,
                                           ep_size=args.ep_size,
                                           slot_count_per_rank=num_local_slots)
        placement_info = MoePlacementCpuInfo()
        placement_info.expert_replica_count = [0] * num_experts
        placement_info.rank_expert_ids = [[0] * num_local_slots
                                          for _ in range(args.ep_size)]
        # do_replication decides how many slots (replicas) each expert gets
        # based on its load; do_placement then assigns those replicas to ranks.
        do_replication(meta_info, expert_load_factor.tolist(), placement_info)
        do_placement(meta_info, expert_load_factor.tolist(), placement_info)
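
        # Intuition (assumed numbers): with 256 experts and 320 slots there
        # are 64 spare slots, which go as extra replicas to the most heavily
        # loaded experts; a replicated expert's per-copy load drops
        # proportionally, which is exactly what the rebalanced estimate
        # below assumes.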

        initial_global_assignments[layer_idx] = []
        for local_expert_ids in placement_info.rank_expert_ids:
            initial_global_assignments[layer_idx].extend(local_expert_ids)

        # Report load statistics: actual per-rank loads under the default
        # (contiguous) expert-to-rank mapping ...
        rank_load_iters = expert_token_count_iters.reshape(
            expert_token_count_iters.size(0), args.ep_size, -1).sum(dim=-1)
        load_stats[layer_idx] = calculate_load_statistics(
            rank_load_iters.float())

        # ... and estimated per-rank loads under the rebalanced placement:
        # each expert's load is split evenly across its replicas, then the
        # per-slot loads are regrouped according to the new assignment.
        token_load_iters_rebalanced = expert_token_count_iters / torch.tensor(
            placement_info.expert_replica_count)
        token_load_iters_rebalanced = token_load_iters_rebalanced[:, torch.tensor(
            initial_global_assignments[layer_idx])]
        rank_load_iters_rebalanced = token_load_iters_rebalanced.reshape(
            expert_token_count_iters.size(0), args.ep_size, -1).sum(dim=-1)
        load_stats_rebalanced[layer_idx] = calculate_load_statistics(
            rank_load_iters_rebalanced.float())

    eplb_config = {
        "num_slots": args.num_slots,
        "initial_global_assignments": initial_global_assignments,
        "layer_updates_per_iter": args.layer_updates_per_iter,
    }
    save_eplb_config(eplb_config, args.output_path)
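
    # Quick sanity check of the generated file (illustrative; assumes
    # num_slots is divisible by ep_size so each layer gets exactly
    # num_slots entries):
    #
    #   cfg = yaml.safe_load(open(args.output_path))
    #   assert all(len(a) == cfg["num_slots"]
    #              for a in cfg["initial_global_assignments"].values())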

    load_stats = pd.DataFrame(load_stats).T
    logger.info(f"Load statistics:\n{load_stats}")
    load_stats_rebalanced = pd.DataFrame(load_stats_rebalanced).T
    logger.info(
        f"Load statistics after rebalance (estimated):\n{load_stats_rebalanced}"
    )