TensorRT-LLMs/examples/wide_ep/ep_load_balancer/utils.py
Xianjie Qiao b1976c2add
Add wide-ep benchmarking scripts (#5760)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
Signed-off-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-05 19:29:39 +08:00

38 lines
1.2 KiB
Python

import glob
import json
import safetensors.torch
import torch
def load_expert_statistic(path: str):
    """Load and merge the per-rank expert statistics stored under ``path``.

    Reads ``{path}/meta_info.json`` and every ``{path}/rank*.safetensors``
    file, sums tensors that share the same key across all rank files, and
    re-keys the merged result by ``(iter_idx, layer_idx)``. Tensor keys in the
    files are expected to have the form ``"<iter_idx>_<layer_idx>"``.

    Args:
        path: Directory containing ``meta_info.json`` and the
            ``rank*.safetensors`` files.

    Returns:
        A ``(meta_info, statistic)`` tuple.

        ``meta_info`` is the parsed JSON object extended with:
          * ``"ep_size"``: number of rank files that were found,
          * ``"iter_start"``: smallest iteration index present,
          * ``"iter_stop"``: largest iteration index present plus one,
          * ``"layers"``: sorted list of the layer indices present.

        ``statistic`` maps ``(iter_idx, layer_idx)`` to the element-wise sum
        of the corresponding tensors over all rank files.

    Raises:
        FileNotFoundError: If ``meta_info.json`` is missing or no
            ``rank*.safetensors`` file exists under ``path``.
        ValueError: If the rank files hold no tensors, a key is not of the
            form ``"<iter_idx>_<layer_idx>"``, the iteration indices are not
            contiguous, or some ``(iter_idx, layer_idx)`` pair is missing.
    """
    with open(f"{path}/meta_info.json", "r", encoding="utf-8") as f:
        meta_info = json.load(f)

    # Escape the directory part so glob metacharacters in ``path`` are taken
    # literally; sort so the accumulation order does not depend on the
    # filesystem's listing order.
    pattern = f"{glob.escape(str(path))}/rank*.safetensors"
    statistic_files = sorted(glob.glob(pattern))
    if not statistic_files:
        raise FileNotFoundError(
            f"No rank*.safetensors files found in {path!r}")

    # Element-wise sum of same-named tensors across all ranks.
    merged = {}
    for statistic_file in statistic_files:
        rank_statistic = safetensors.torch.load_file(statistic_file)
        for key, data in rank_statistic.items():
            if key not in merged:
                merged[key] = torch.zeros_like(data)
            merged[key] += data

    if not merged:
        raise ValueError(
            f"The rank*.safetensors files in {path!r} contain no tensors")

    # Re-key "<iter_idx>_<layer_idx>" strings as (iter_idx, layer_idx) tuples.
    statistic = {}
    for key, data in merged.items():
        iter_idx, layer_idx = key.split("_")
        statistic[int(iter_idx), int(layer_idx)] = data

    iters = sorted({iter_idx for iter_idx, _ in statistic})
    layers = sorted({layer_idx for _, layer_idx in statistic})
    num_iters = len(iters)
    num_layers = len(layers)

    # Explicit raises instead of ``assert`` so the validation still runs
    # when Python is started with -O.
    if iters[-1] + 1 - iters[0] != num_iters:
        raise ValueError(
            f"Iteration indices are not contiguous: found {num_iters} "
            f"distinct iterations in range [{iters[0]}, {iters[-1]}]")
    if len(statistic) != num_iters * num_layers:
        raise ValueError(
            f"Incomplete statistics: expected {num_iters * num_layers} "
            f"(iteration, layer) entries, found {len(statistic)}")

    meta_info["ep_size"] = len(statistic_files)
    meta_info["iter_start"] = iters[0]
    meta_info["iter_stop"] = iters[-1] + 1
    meta_info["layers"] = layers
    return meta_info, statistic