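"""Generate an offline MoE Expert Parallelism Load Balancer (EPLB) config.

The script aggregates the per-rank expert token statistics dumped under
--expert_statistic_path (or the EXPERT_STATISTIC_PATH environment variable),
runs the replication and placement routines from tensorrt_llm's internal
runtime bindings for every MoE layer, writes the resulting initial
expert-to-slot assignments to a YAML config, and logs per-rank load
statistics before and after rebalancing.

Example invocation (the script name is illustrative; flags are defined in
the argument parser below):

    python generate_eplb_config.py --expert_statistic_path expert_statistic \
        --ep_size 8 --num_slots 320 --output_path moe_load_balancer.yaml
"""
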
import argparse
import glob
import json
import os

import pandas as pd
import safetensors.torch  # import the torch submodule for load_file below
import torch
import yaml

from tensorrt_llm.bindings.internal.runtime import (MoeLoadBalanceMetaInfo,
                                                    MoePlacementCpuInfo,
                                                    do_placement,
                                                    do_replication)
from tensorrt_llm.logger import logger

logger.set_level("info")


def calculate_load_statistics(load_iters: torch.Tensor):
    """Return load balance statistics for a (num_iters, num_ranks) tensor."""
    # Sum the loads over iterations and calculate the statistics.
    load_total = load_iters.sum(dim=0)
    mean = load_total.mean().item()
    std = load_total.std().item()
    imbalance_ratio = (load_total.max().item() - load_total.min().item()) / mean
    stats = {
        "mean-total": mean,
        "std-total": std,
        "imbalance-ratio-total": imbalance_ratio
    }
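
    # Example: per-rank totals of [90, 100, 110, 100] give mean-total = 100
    # and imbalance-ratio-total = (110 - 90) / 100 = 0.2.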

    # Calculate the statistics for each iteration and average over iterations.
    mean = load_iters.mean(dim=-1).mean().item()
    std = load_iters.std(dim=-1).mean().item()
    imbalance_ratio = (load_iters.max(dim=-1).values -
                       load_iters.min(dim=-1).values) / load_iters.mean(dim=-1)
    imbalance_ratio = imbalance_ratio.mean().item()
    stats.update({
        "mean-iter": mean,
        "std-iter": std,
        "imbalance-ratio-iter": imbalance_ratio
    })
    return stats


def save_eplb_config(config: dict, path: str):

    def represent_list_inline(dumper, data):
        return dumper.represent_sequence('tag:yaml.org,2002:seq',
                                         data,
                                         flow_style=True)

    # Emit lists in flow style so each per-layer assignment stays on one line.
    yaml.add_representer(list, represent_list_inline)

    with open(path, "w") as f:
        yaml.dump(config, f, width=float('inf'))
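
# The emitted YAML looks roughly like this (values are illustrative):
#
#   initial_global_assignments:
#     0: [27, 4, 33, ...]
#     1: [12, 58, 9, ...]
#   layer_updates_per_iter: 0
#   num_slots: 320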


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--expert_statistic_path",
                        type=str,
                        default=os.environ.get("EXPERT_STATISTIC_PATH",
                                               "expert_statistic"))
    parser.add_argument("--iter_start", type=int, default=None)
    parser.add_argument("--iter_stop", type=int, default=None)
    parser.add_argument("--output_path",
                        type=str,
                        default="moe_load_balancer.yaml")
    parser.add_argument("--ep_size", type=int, default=8)
    parser.add_argument("--num_slots", type=int, default=320)
    parser.add_argument("--layer_updates_per_iter", type=int, default=0)
    args = parser.parse_args()

    with open(f"{args.expert_statistic_path}/meta_info.json", "r") as f:
        meta_info = json.load(f)
    num_experts = meta_info["num_experts"]
    num_experts_per_token = meta_info["num_experts_per_token"]

    # Aggregate token counts across ranks; keys look like "{iter}_{layer}".
    statistic = {}
    for statistic_file in glob.glob(
            f"{args.expert_statistic_path}/rank*.safetensors"):
        rank_statistic = safetensors.torch.load_file(statistic_file)
        for key, data in rank_statistic.items():
            if key not in statistic:
                statistic[key] = torch.zeros_like(data)
            statistic[key] += data

    def parse_key(key: str) -> tuple[int, int]:
        iter_idx, layer_idx = key.split("_")
        return int(iter_idx), int(layer_idx)

    statistic = {parse_key(key): data for key, data in statistic.items()}

    iters = sorted(set(iter_idx for iter_idx, _ in statistic))
    layers = sorted(set(layer_idx for _, layer_idx in statistic))
    num_iters = len(iters)
    num_layers = len(layers)
    # The recorded iterations must be contiguous and cover every layer.
    assert iters[-1] + 1 - iters[0] == num_iters
    assert len(statistic) == num_iters * num_layers

    if args.iter_start is None:
        args.iter_start = iters[0]
    if args.iter_stop is None:
        args.iter_stop = iters[-1] + 1
    logger.info(f"Statistic iterations: {iters[0]} - {iters[-1] + 1}")
    logger.info(f"Used iterations: {args.iter_start} - {args.iter_stop}")
    logger.info(f"Statistic layers: {layers}")

    num_local_slots = args.num_slots // args.ep_size
    initial_global_assignments = {}
    load_stats = {}
    load_stats_rebalanced = {}

    for layer_idx in layers:
        # Stack the selected iterations for this layer into a
        # (num_used_iters, num_experts) tensor of token counts.
        expert_token_count_iters = [
            data for key, data in statistic.items()
            if args.iter_start <= key[0] < args.iter_stop
            and key[1] == layer_idx
        ]
        expert_token_count_iters = torch.stack(expert_token_count_iters, dim=0)
        assert (expert_token_count_iters.size(0) ==
                args.iter_stop - args.iter_start)
        expert_load_factor = expert_token_count_iters.sum(dim=0).float()

        meta_info = MoeLoadBalanceMetaInfo(expert_count=num_experts,
                                           top_k=num_experts_per_token,
                                           ep_rank=0,
                                           ep_size=args.ep_size,
                                           slot_count_per_rank=num_local_slots)
        placement_info = MoePlacementCpuInfo()
        placement_info.expert_replica_count = [0] * num_experts
        placement_info.rank_expert_ids = [[0] * num_local_slots
                                          for _ in range(args.ep_size)]
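
        # As used here, do_replication fills placement_info.expert_replica_count
        # (number of copies per expert) and do_placement fills rank_expert_ids
        # (the expert id held by each local slot on each EP rank).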
        do_replication(meta_info, expert_load_factor.tolist(), placement_info)
        do_placement(meta_info, expert_load_factor.tolist(), placement_info)

        # Flatten rank_expert_ids into the global slot order: rank 0's slots
        # first, then rank 1's, and so on.
        initial_global_assignments[layer_idx] = []
        for local_expert_ids in placement_info.rank_expert_ids:
            initial_global_assignments[layer_idx].extend(local_expert_ids)

        # Report load statistics.
        rank_load_iters = expert_token_count_iters.reshape(
            expert_token_count_iters.size(0), args.ep_size, -1).sum(dim=-1)
        load_stats[layer_idx] = calculate_load_statistics(
            rank_load_iters.float())

        # Estimate the post-rebalance load: split each expert's tokens evenly
        # across its replicas, gather the per-slot loads in global slot order,
        # and sum them per rank.
        token_load_iters_rebalanced = expert_token_count_iters / torch.tensor(
            placement_info.expert_replica_count)
        slot_expert_ids = torch.tensor(initial_global_assignments[layer_idx])
        token_load_iters_rebalanced = token_load_iters_rebalanced[
            :, slot_expert_ids]
        rank_load_iters_rebalanced = token_load_iters_rebalanced.reshape(
            expert_token_count_iters.size(0), args.ep_size, -1).sum(dim=-1)
        load_stats_rebalanced[layer_idx] = calculate_load_statistics(
            rank_load_iters_rebalanced.float())

    eplb_config = {
        "num_slots": args.num_slots,
        "initial_global_assignments": initial_global_assignments,
        "layer_updates_per_iter": args.layer_updates_per_iter,
    }
    save_eplb_config(eplb_config, args.output_path)

    load_stats = pd.DataFrame(load_stats).T
    logger.info(f"Load statistics:\n{load_stats}")

    load_stats_rebalanced = pd.DataFrame(load_stats_rebalanced).T
    logger.info(
        f"Load statistics after rebalance (estimated):\n{load_stats_rebalanced}"
    )