mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-23 12:12:39 +08:00
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com> Signed-off-by: Jun Yang <143764042+juney-nvidia@users.noreply.github.com> Co-authored-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Co-authored-by: Jun Yang <143764042+juney-nvidia@users.noreply.github.com>
109 lines
4.2 KiB
Python
import argparse
|
|
|
|
import torch
|
|
import yaml
|
|
from utils import load_expert_statistic
|
|
|
|
from tensorrt_llm.bindings.internal.runtime import (MoeLoadBalanceMetaInfo,
|
|
MoePlacementCpuInfo,
|
|
do_placement,
|
|
do_replication)
|
|
|
|
|
|
def save_eplb_config(config: dict, path: str):
    """Write ``config`` to ``path`` as YAML, emitting every list in flow style.

    Lists (e.g. the per-layer expert-slot assignments) are written on one
    line as ``[a, b, c]`` instead of one ``- item`` line per element, and
    ``width=float('inf')`` disables PyYAML's line wrapping so those lines
    are never split.

    The inline-list representer is registered on a private ``yaml.Dumper``
    subclass rather than on ``yaml.Dumper`` itself, so calling this function
    does not change how any other ``yaml.dump`` call in the process formats
    lists.

    Args:
        config: Mapping to serialize.
        path: Destination file path; an existing file is overwritten.
    """

    class _InlineListDumper(yaml.Dumper):
        """Dumper whose representer table is isolated from ``yaml.Dumper``."""

    def represent_list_inline(dumper, data):
        # Same tag as a normal YAML sequence, only the layout differs.
        return dumper.represent_sequence('tag:yaml.org,2002:seq',
                                         data,
                                         flow_style=True)

    # add_representer on a subclass copies the inherited representer table
    # before modifying it, so the global yaml.Dumper is left untouched.
    _InlineListDumper.add_representer(list, represent_list_inline)

    with open(path, "w") as f:
        yaml.dump(config, f, Dumper=_InlineListDumper, width=float('inf'))
|
|
|
|
|
|
def collect_expert_load(statistic, layer_idx, iter_start, iter_stop):
    """Sum one layer's per-expert token counts over ``[iter_start, iter_stop)``.

    Args:
        statistic: Mapping as returned by ``load_expert_statistic``. Each key
            is indexed here as ``key[0]`` = iteration and ``key[1]`` = layer
            index; each value is a tensor of per-expert token counts
            (presumably 1-D of length num_experts — TODO confirm against
            ``load_expert_statistic``).
        layer_idx: Layer whose statistics are aggregated.
        iter_start: First iteration included.
        iter_stop: First iteration excluded.

    Returns:
        A float32 tensor holding the element-wise sum of the selected counts.

    Raises:
        ValueError: If the iteration range is empty, or if the number of
            matching entries differs from ``iter_stop - iter_start`` (i.e.
            some iteration of this layer is missing from ``statistic``).
    """
    num_iters = iter_stop - iter_start
    if num_iters <= 0:
        raise ValueError(
            f"Empty iteration range [{iter_start}, {iter_stop}).")

    counts = [
        data for key, data in statistic.items()
        if iter_start <= key[0] < iter_stop and key[1] == layer_idx
    ]
    # An explicit check instead of `assert`, which is stripped under -O.
    if len(counts) != num_iters:
        raise ValueError(
            f"Layer {layer_idx}: expected statistics for {num_iters} "
            f"iterations in [{iter_start}, {iter_stop}), found {len(counts)}.")

    return torch.stack(counts, dim=0).sum(dim=0).float()


def place_experts(num_experts, top_k, ep_size, num_local_slots, expert_load):
    """Compute one layer's expert-to-slot placement via the load-balancer bindings.

    ``do_replication`` and ``do_placement`` receive the same load vector and
    are expected to fill ``placement_info`` in place; the result is read back
    from ``placement_info.rank_expert_ids``, which is pre-sized to one list
    of ``num_local_slots`` expert ids per rank.

    Args:
        num_experts: Number of logical experts in the layer.
        top_k: Number of experts each token is routed to.
        ep_size: Expert-parallel world size.
        num_local_slots: Expert slots per rank.
        expert_load: Per-expert load as a Python list of floats.

    Returns:
        Flat list of expert ids: rank 0's slots first, then rank 1's, etc.
    """
    # NOTE(review): ep_rank is fixed to 0; presumably the computed placement
    # covers every rank regardless of this value — confirm.
    moelb_info = MoeLoadBalanceMetaInfo(expert_count=num_experts,
                                        top_k=top_k,
                                        ep_rank=0,
                                        ep_size=ep_size,
                                        slot_count_per_rank=num_local_slots)
    placement_info = MoePlacementCpuInfo()
    placement_info.expert_replica_count = [0] * num_experts
    placement_info.rank_expert_ids = [[0] * num_local_slots
                                      for _ in range(ep_size)]

    do_replication(moelb_info, expert_load, placement_info)
    do_placement(moelb_info, expert_load, placement_info)

    return [
        expert_id for local_expert_ids in placement_info.rank_expert_ids
        for expert_id in local_expert_ids
    ]


if __name__ == "__main__":
    # Build an EPLB config (slot count, initial per-layer expert placement,
    # layers updated per iteration) from recorded expert statistics.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--expert_statistic_path",
        type=str,
        required=True,
        help="The directory path to the expert statistic files.")
    parser.add_argument(
        "--iter_start",
        type=int,
        default=None,
        help="The start iteration of used iterations (inclusive). Defaults "
        "to the iter_start recorded in the expert statistic meta info.")
    parser.add_argument(
        "--iter_stop",
        type=int,
        default=None,
        help="The stop iteration of used iterations (exclusive). Defaults "
        "to the iter_stop recorded in the expert statistic meta info.")
    parser.add_argument("--output_path",
                        type=str,
                        required=True,
                        help="The output path to the eplb config file.")
    parser.add_argument(
        "--ep_size",
        type=int,
        default=None,
        help="The expert parallelism size after load rebalance. Defaults "
        "to the ep_size recorded in the expert statistic meta info.")
    parser.add_argument(
        "--num_slots",
        type=int,
        default=None,
        help="The total number of expert slots after load rebalance. Must "
        "be divisible by ep_size and at least the number of experts. "
        "Defaults to the number of experts.")
    parser.add_argument("--layer_updates_per_iter",
                        type=int,
                        default=0,
                        help="The number of layers to update per iteration.")
    args = parser.parse_args()

    meta_info, statistic = load_expert_statistic(args.expert_statistic_path)
    num_experts = meta_info["num_experts"]
    num_experts_per_token = meta_info["num_experts_per_token"]

    # Unset CLI options fall back to the values recorded with the statistics.
    if args.ep_size is None:
        args.ep_size = meta_info["ep_size"]
    if args.num_slots is None:
        args.num_slots = num_experts
    if args.iter_start is None:
        args.iter_start = meta_info["iter_start"]
    if args.iter_stop is None:
        args.iter_stop = meta_info["iter_stop"]

    # Reject configurations that would otherwise be silently truncated or
    # produce an unusable placement.
    if args.ep_size <= 0:
        parser.error(f"--ep_size must be positive, got {args.ep_size}.")
    if args.num_slots % args.ep_size != 0:
        parser.error(f"--num_slots ({args.num_slots}) must be divisible by "
                     f"--ep_size ({args.ep_size}).")
    if args.num_slots < num_experts:
        parser.error(f"--num_slots ({args.num_slots}) must be at least the "
                     f"number of experts ({num_experts}).")
    if args.iter_stop <= args.iter_start:
        parser.error(f"--iter_stop ({args.iter_stop}) must be greater than "
                     f"--iter_start ({args.iter_start}).")

    num_local_slots = args.num_slots // args.ep_size

    initial_global_assignments = {}
    for layer_idx in meta_info["layers"]:
        expert_load_factor = collect_expert_load(statistic, layer_idx,
                                                 args.iter_start,
                                                 args.iter_stop)
        initial_global_assignments[layer_idx] = place_experts(
            num_experts, num_experts_per_token, args.ep_size,
            num_local_slots, expert_load_factor.tolist())

    eplb_config = {
        "num_slots": args.num_slots,
        "initial_global_assignments": initial_global_assignments,
        "layer_updates_per_iter": args.layer_updates_per_iter,
    }
    save_eplb_config(eplb_config, args.output_path)
|