TensorRT-LLMs/examples/ep_load_balancer/utils.py
Enwei Zhu 00991d1520
chore: Merge remaining changes from feat/large-ep branch to main (#5039)
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Jun Yang <143764042+juney-nvidia@users.noreply.github.com>
Co-authored-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: Jun Yang <143764042+juney-nvidia@users.noreply.github.com>
2025-06-11 13:47:43 +08:00

38 lines
1.2 KiB
Python

import glob
import json
import safetensors.torch
import torch
def load_expert_statistic(path: str) -> tuple[dict, dict[tuple[int, int], torch.Tensor]]:
    """Load and merge per-rank expert statistics stored under ``path``.

    Reads ``{path}/meta_info.json`` and every ``{path}/rank*.safetensors``
    file. Tensors that share the same key across rank files are summed
    element-wise, so the result aggregates all ranks. Each tensor key is
    expected to have the form ``"<iter_idx>_<layer_idx>"``.

    Args:
        path: Directory containing ``meta_info.json`` and the per-rank
            ``rank*.safetensors`` files.

    Returns:
        A ``(meta_info, statistic)`` pair.

        ``meta_info`` is the parsed JSON object with these keys added or
        overwritten:

        * ``"ep_size"``: the number of rank files found.
        * ``"iter_start"`` / ``"iter_stop"``: the half-open range of
          iteration indices present.
        * ``"layers"``: the sorted list of layer indices present.

        ``statistic`` maps ``(iter_idx, layer_idx)`` to the summed tensor.

    Raises:
        FileNotFoundError: If ``meta_info.json`` is missing, or if no
            ``rank*.safetensors`` file is found in ``path``.
        ValueError: If the rank files contain no tensors, if the iteration
            indices are not contiguous, or if some ``(iter, layer)``
            combination is missing.
    """
    with open(f"{path}/meta_info.json", "r", encoding="utf-8") as f:
        meta_info = json.load(f)

    # Escape the directory part so glob metacharacters in ``path`` (e.g. "[")
    # are matched literally. Sort so that the floating-point summation order
    # below does not depend on filesystem listing order.
    statistic_files = sorted(glob.glob(f"{glob.escape(path)}/rank*.safetensors"))
    if not statistic_files:
        raise FileNotFoundError(
            f"No rank*.safetensors statistic files found in {path!r}")

    # Sum the tensors for each key across all rank files.
    statistic = {}
    for statistic_file in statistic_files:
        rank_statistic = safetensors.torch.load_file(statistic_file)
        for key, data in rank_statistic.items():
            if key not in statistic:
                statistic[key] = torch.zeros_like(data)
            statistic[key] += data

    def parse_key(key: str) -> tuple[int, int]:
        # Keys are "<iter_idx>_<layer_idx>".
        iter_idx, layer_idx = key.split("_")
        return int(iter_idx), int(layer_idx)

    statistic = {parse_key(key): data for key, data in statistic.items()}
    if not statistic:
        raise ValueError(
            f"Statistic files in {path!r} contain no tensors")

    iters = sorted({iter_idx for iter_idx, _ in statistic})
    layers = sorted({layer_idx for _, layer_idx in statistic})
    num_iters = len(iters)
    num_layers = len(layers)

    # The iterations must form a contiguous range [iters[0], iters[-1]], and
    # every (iter, layer) pair in that range must be present.
    if iters[-1] + 1 - iters[0] != num_iters:
        raise ValueError(
            f"Non-contiguous iterations in {path!r}: "
            f"found {num_iters} distinct iterations in "
            f"[{iters[0]}, {iters[-1]}]")
    if len(statistic) != num_iters * num_layers:
        raise ValueError(
            f"Incomplete statistics in {path!r}: expected "
            f"{num_iters} iterations x {num_layers} layers = "
            f"{num_iters * num_layers} entries, found {len(statistic)}")

    meta_info["ep_size"] = len(statistic_files)
    meta_info["iter_start"] = iters[0]
    meta_info["iter_stop"] = iters[-1] + 1
    meta_info["layers"] = layers
    return meta_info, statistic