TensorRT-LLMs/examples/ep_load_balancer/report_load_statistics.py
Enwei Zhu 00991d1520
chore: Merge remaining changes from feat/large-ep branch to main (#5039)
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Jun Yang <143764042+juney-nvidia@users.noreply.github.com>
Co-authored-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: Jun Yang <143764042+juney-nvidia@users.noreply.github.com>
2025-06-11 13:47:43 +08:00

69 lines
2.5 KiB
Python

import argparse
import pandas as pd
import torch
from utils import load_expert_statistic
def calculate_load_statistics(load_iters: torch.Tensor):
    """Summarize how evenly load is spread along the last dimension.

    Every statistic is first reduced over the last dimension of
    ``load_iters`` and the per-row results are then averaged into one scalar.

    Args:
        load_iters: Floating-point tensor of load values. ``mean`` and
            ``std`` are called on it directly, so integer tensors must be
            converted by the caller first.

    Returns:
        A dict with three Python floats:
            ``"mean"``: average of the per-row means.
            ``"std"``: average of the per-row standard deviations
                (``torch.Tensor.std`` defaults).
            ``"imbalance-ratio"``: average of ``max / mean - 1`` per row.
    """
    per_row_mean = load_iters.mean(dim=-1)
    per_row_max = load_iters.max(dim=-1).values

    mean = per_row_mean.mean().item()
    std = load_iters.std(dim=-1).mean().item()

    imbalance_ratio = per_row_max / per_row_mean - 1
    # A row whose max and mean are both zero carries no load at all; the
    # division above is 0 / 0 = NaN there, and a single NaN would poison the
    # average over all rows. Such a row is perfectly balanced (its std is 0),
    # so report a ratio of 0 for it. All other rows are left untouched.
    idle_rows = (per_row_max == 0) & (per_row_mean == 0)
    imbalance_ratio = torch.where(idle_rows,
                                  torch.zeros_like(imbalance_ratio),
                                  imbalance_ratio)
    imbalance_ratio = imbalance_ratio.mean().item()

    return {"mean": mean, "std": std, "imbalance-ratio": imbalance_ratio}
def _parse_args():
    """Parse the command-line options of the load-statistics report."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--expert_statistic_path",
        type=str,
        required=True,
        help="The directory path to the expert statistic files.")
    parser.add_argument("--iter_start",
                        type=int,
                        default=None,
                        help="The start iteration of used iterations.")
    parser.add_argument("--iter_stop",
                        type=int,
                        default=None,
                        help="The end iteration of used iterations.")
    parser.add_argument("--per_expert",
                        default=False,
                        action="store_true",
                        help="Report the load statistics per expert.")
    return parser.parse_args()


def main():
    """Load the expert statistics and print per-layer load statistics.

    For every layer listed in the meta info, the token counts recorded for
    iterations in ``[iter_start, iter_stop)`` are stacked and passed to
    ``calculate_load_statistics``. Without ``--per_expert`` the counts are
    first summed within ``ep_size`` groups. The result is printed as a table
    with one row per layer plus an ``average`` row.

    Raises:
        ValueError: If the iteration range is empty, or a layer does not have
            exactly one record per iteration in the range.
    """
    args = _parse_args()
    meta_info, statistic = load_expert_statistic(args.expert_statistic_path)

    # Fall back to the recorded range when no bound is given on the CLI.
    iter_start = args.iter_start
    if iter_start is None:
        iter_start = meta_info["iter_start"]
    iter_stop = args.iter_stop
    if iter_stop is None:
        iter_stop = meta_info["iter_stop"]

    num_iters = iter_stop - iter_start
    if num_iters <= 0:
        # Otherwise torch.stack below fails on an empty list with an
        # unrelated-looking message.
        raise ValueError(
            f"Empty iteration range: iter_start={iter_start}, "
            f"iter_stop={iter_stop}.")

    load_stats = {}
    for layer_idx in meta_info["layers"]:
        # Keys of ``statistic`` are indexed as (iteration, layer, ...).
        layer_counts = [
            data for key, data in statistic.items()
            if iter_start <= key[0] < iter_stop and key[1] == layer_idx
        ]
        # Explicit check instead of ``assert`` so it also runs under
        # ``python -O``; the reshape below relies on this count.
        if len(layer_counts) != num_iters:
            raise ValueError(
                f"Layer {layer_idx}: expected {num_iters} iterations in "
                f"[{iter_start}, {iter_stop}), found {len(layer_counts)}.")
        expert_token_count_iters = torch.stack(layer_counts, dim=0)

        if args.per_expert:
            load_iters = expert_token_count_iters
        else:
            # Split each iteration's counts into ep_size equal groups and sum
            # within each group.
            load_iters = expert_token_count_iters.reshape(
                num_iters, meta_info["ep_size"], -1).sum(dim=-1)
        load_stats[layer_idx] = calculate_load_statistics(load_iters.float())

    # Columns are layers at this point; add the cross-layer average as an
    # extra column, then transpose so each layer becomes a row.
    load_stats = pd.DataFrame(load_stats)
    load_stats["average"] = load_stats.mean(axis=1)
    load_stats = load_stats.T
    print(f"Load statistics:\n{load_stats}")


if __name__ == "__main__":
    main()