#!/usr/bin/env python3
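"""Submit disaggregated (context/generation split) benchmark jobs to Slurm.

Each input YAML config describes one job (sections: slurm, hardware,
environment, worker_config, benchmark, plus optional accuracy and
profiling). For every config, this script computes the node/GPU
allocation for the context (CTX) and generation (GEN) servers, writes
worker configs and launch scripts into the log directory, and submits
the job via sbatch. Run it with -c/--config for a single YAML file or
-d/--dir for a directory of YAML files; --dry-run prints the sbatch
command without submitting.
"""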
import argparse
import glob
import json
import math
import os
import shutil
import subprocess
import sys
import traceback
from datetime import datetime
from typing import Any, Dict

import yaml

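# A minimal config sketch (illustrative only; values are placeholders and
# only keys this script reads are shown -- the worker_config sections are
# passed through to the workers verbatim):
#
#   slurm:
#     partition: <partition>
#     account: <account>
#     job_time: "02:00:00"
#     job_name: <job_name>
#     numa_bind: true
#     script_file: <path/to/slurm/script>
#   hardware:
#     num_ctx_servers: 2
#     num_gen_servers: 1
#     gpus_per_node: 4
#   environment:
#     work_dir: <path/to/work_dir>
#     container_image: <image>
#     container_mount: <mounts>
#     model_path: <path/to/model>
#   worker_config:
#     ctx:
#       tensor_parallel_size: 2
#     gen:
#       tensor_parallel_size: 4
#       max_batch_size: 128
#       enable_attention_dp: false
#   benchmark:
#     mode: <mode>             # "gen_only_no_context" gets special handling
#     input_length: 1024
#     output_length: 1024
#     concurrency_list: "1 2 4"
#     multi_round: 1
#     streaming: true
#     use_nv_sa_benchmark: false
#     dataset_file: <path/to/dataset>   # benchmark_ratio is used instead
#                                       # when use_nv_sa_benchmark is true

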
def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Submit disaggregated benchmark job')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-c',
                       '--config',
                       type=str,
                       help='Path to the configuration YAML file')
    group.add_argument('-d',
                       '--dir',
                       type=str,
                       help='Directory containing YAML configuration files')
    parser.add_argument('--log-dir',
                        type=str,
                        default=None,
                        help='Log directory')
    parser.add_argument('--dry-run',
                        action='store_true',
                        help='Dry-run the Python part; for testing purposes only')
    return parser.parse_args()


def load_config(config_path):
    """Load a YAML configuration file."""
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def save_worker_config(worker_config, output_path):
    """Save worker config to a separate YAML file."""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        yaml.dump(worker_config, f, default_flow_style=False)


def calculate_nodes(world_size, num_servers, gpus_per_node):
    """Calculate required nodes based on world size and server count."""
    return math.ceil(world_size * num_servers / gpus_per_node)


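# For illustration: with 4-GPU nodes, 2 CTX servers at world size 2 (TP=2)
# need ceil(2 * 2 / 4) = 1 node, and 1 GEN server at world size 4 (TP=4)
# needs ceil(4 * 1 / 4) = 1 node, so such a job requests 2 nodes in total.

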
def allocate_gpus(
    total_nodes: int,
    gpus_per_node: int,
    num_gen_servers: int,
    num_ctx_servers: int,
    gen_world_size: int,
    ctx_world_size: int,
    base_port: int = 8000,
) -> Dict[str, Any]:
    """Assign GPUs to GEN and CTX servers, packing them node by node.

    Returns a mapping of server type ("GEN"/"CTX") to server id to an
    allocation of the form {"port": ..., "nodes": {hostname: [gpu_ids]}}.
    Hostnames are placeholder strings of the form "<node{i}_placeholder>".
    """
    allocations = {}
    hostnames = [f"<node{i}_placeholder>" for i in range(total_nodes)]

    global_gpu_cursor = 0

    def get_gpu_location(gpus_per_node: int):
        node_id = global_gpu_cursor // gpus_per_node
        local_gpu_id = global_gpu_cursor % gpus_per_node
        return node_id, local_gpu_id

    def assign_server(server_allocation: Dict[str, Any], world_size: int,
                      gpus_per_node: int):
        nonlocal global_gpu_cursor
        for _ in range(world_size):
            node_id, gpu_id = get_gpu_location(gpus_per_node)
            hostname = hostnames[node_id]
            if hostname not in server_allocation["nodes"]:
                server_allocation["nodes"][hostname] = []
            server_allocation["nodes"][hostname].append(gpu_id)
            global_gpu_cursor += 1

    def assign_servers(
        server_allocations: Dict[str, Any],
        server_type: str,
        num_servers: int,
        world_size: int,
        gpus_per_node: int,
    ):
        if server_type not in server_allocations:
            server_allocations[server_type] = {}
        for i in range(num_servers):
            server_allocation = {
                "port": base_port + i,
                "nodes": {},
            }
            assign_server(server_allocation, world_size, gpus_per_node)
            server_allocations[server_type][i] = server_allocation

    assign_servers(allocations, "GEN", num_gen_servers, gen_world_size,
                   gpus_per_node)
    assign_servers(allocations, "CTX", num_ctx_servers, ctx_world_size,
                   gpus_per_node)

    return allocations


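# For illustration (not executed): with total_nodes=2, gpus_per_node=4,
# one GEN server of world size 4 and two CTX servers of world size 2,
# allocate_gpus(...) returns:
#
#   {"GEN": {0: {"port": 8000,
#                "nodes": {"<node0_placeholder>": [0, 1, 2, 3]}}},
#    "CTX": {0: {"port": 8000, "nodes": {"<node1_placeholder>": [0, 1]}},
#            1: {"port": 8001, "nodes": {"<node1_placeholder>": [2, 3]}}}}

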
def convert_allocations_to_server_config(allocations, server_port=8333):
    """Build the disaggregated server config (hostname, port, worker URLs)
    from the GPU allocations."""
    generation_servers = {}
    context_servers = {}
    server_hostname = None
    for server_type in allocations.keys():
        num_servers = len(allocations[server_type])
        urls = []
        for server_id in allocations[server_type].keys():
            instance = allocations[server_type][server_id]
            urls.append(
                f"{list(instance['nodes'].keys())[0]}:{instance['port']}")
        if server_type == "GEN":
            generation_servers = {'num_instances': num_servers, 'urls': urls}
            # The disagg server runs on the first GEN node.
            server_hostname = urls[0].split(':')[0]
            # Note: this checks the port of the last GEN server iterated above.
            if allocations[server_type][server_id]['port'] == server_port:
                server_port += 1  # Avoid port conflict
        elif server_type == "CTX":
            context_servers = {'num_instances': num_servers, 'urls': urls}

    server_config = {
        'backend': 'pytorch',
        'hostname': server_hostname,
        'port': server_port,
        'context_servers': context_servers,
        'generation_servers': generation_servers
    }
    return server_config


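# For illustration (not executed): for the allocation example above,
# convert_allocations_to_server_config(...) produces:
#
#   {'backend': 'pytorch',
#    'hostname': '<node0_placeholder>',
#    'port': 8333,
#    'context_servers': {'num_instances': 2,
#                        'urls': ['<node1_placeholder>:8000',
#                                 '<node1_placeholder>:8001']},
#    'generation_servers': {'num_instances': 1,
#                           'urls': ['<node0_placeholder>:8000']}}

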
def submit_job(config, log_dir, dry_run):
    """Generate launch scripts for one benchmark config and submit the
    Slurm job (or print the sbatch command in dry-run mode)."""
    # Extract configurations
    slurm_config = config['slurm']
    slurm_config.setdefault('extra_args', '')
    slurm_config.setdefault('set_segment', True)

    hw_config = config['hardware']
    env_config = config['environment']
    worker_config = config['worker_config']
    benchmark_config = config['benchmark']

    # Set default accuracy configuration for backward compatibility
    if 'accuracy' not in config:
        config['accuracy'] = {
            'enable_accuracy_test': False,
            'model': 'local-completions',
            'tasks': 'gsm8k',
            'model_args_extra':
            'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
        }

    # Set default environment configuration for backward compatibility
    env_config.setdefault('trtllm_repo', '')
    env_config.setdefault('build_wheel', False)
    env_config.setdefault('cuda_architectures', '')
    env_config.setdefault('trtllm_wheel_path', '')
    env_config.setdefault('worker_env_var', '')
    env_config.setdefault('server_env_var', '')

    worker_env_var = env_config.get('worker_env_var')
    server_env_var = env_config.get('server_env_var')
    if benchmark_config['mode'] == "gen_only_no_context":
        worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
        server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"

    profiling_config = config.get('profiling', {})
    profiling_config.setdefault('nsys_on', False)
    profiling_config.setdefault('ctx_profile_range', '10-30')
    profiling_config.setdefault('gen_profile_range', '200-250')

    # Get number of servers from config
    ctx_num = hw_config['num_ctx_servers']
    gen_num = hw_config['num_gen_servers']
    gpus_per_node = hw_config['gpus_per_node']

    # Calculate nodes based on world sizes
    ctx_tp_size = worker_config['ctx'].get('tensor_parallel_size', 1)
    ctx_cp_size = worker_config['ctx'].get('context_parallel_size', 1)
    ctx_pp_size = worker_config['ctx'].get('pipeline_parallel_size', 1)
    ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size
    ctx_nodes = calculate_nodes(ctx_world_size, ctx_num, gpus_per_node)

    gen_tp_size = worker_config['gen'].get('tensor_parallel_size', 1)
    gen_cp_size = worker_config['gen'].get('context_parallel_size', 1)
    gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1)
    gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
    gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)

    total_nodes = ctx_nodes + gen_nodes
    total_tasks = total_nodes * gpus_per_node

    # Generate log directory path based on configuration
    isl = benchmark_config['input_length']
    osl = benchmark_config['output_length']
    gen_batch_size = worker_config['gen']['max_batch_size']
    gen_enable_attention_dp = worker_config['gen']['enable_attention_dp']

    if log_dir is None:
        # Create base log directory path
        date_prefix = datetime.now().strftime("%Y%m%d")
        log_base = os.path.join(env_config['work_dir'],
                                f"{date_prefix}/{isl}-{osl}")

        # Get eplb num_slots for gen worker
        load_balancer_config = worker_config['gen'].get('moe_config', {}).get(
            'load_balancer', {})
        if isinstance(load_balancer_config, str):
            with open(load_balancer_config, 'r') as f:
                load_balancer_config = yaml.safe_load(f)
        eplb_num_slots = load_balancer_config.get('num_slots', 0)

        # Get mtp_size from gen config's speculative_config
        mtp_size = worker_config['gen'].get('speculative_config', {}).get(
            'num_nextn_predict_layers', 0)

        # Determine directory suffix based on attention_dp
        if gen_enable_attention_dp:
            dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
        else:
            dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"

        # Create full log directory path
        log_dir = os.path.join(log_base, dir_suffix)

    # Remove existing directory if it exists
    if os.path.exists(log_dir):
        print(f"[WARNING] Removing existing log directory: {log_dir}")
        shutil.rmtree(log_dir)
    os.makedirs(log_dir)
    print(f"Log will be saved to: {log_dir}")

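    # For illustration (placeholder values): with isl=1024, osl=1024,
    # ctx_num=2, gen_num=1, gen_tp_size=4, max_batch_size=128, no EPLB/MTP,
    # and attention DP disabled, the derived path is
    # <work_dir>/<YYYYMMDD>/1024-1024/disagg_ctx2_gen1_tep4_batch128_eplb0_mtp0
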
    # Setup config file paths and save worker configs
    ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
    gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
    save_worker_config(worker_config['ctx'], ctx_config_path)
    save_worker_config(worker_config['gen'], gen_config_path)

    # Prepare allocation template
    allocations = allocate_gpus(
        total_nodes=total_nodes,
        gpus_per_node=gpus_per_node,
        num_gen_servers=gen_num,
        num_ctx_servers=ctx_num,
        gen_world_size=gen_world_size,
        ctx_world_size=ctx_world_size,
    )
    with open(os.path.join(log_dir, "allocations.json"), "w") as f:
        json.dump(allocations, f, indent=2)

    # Generate disagg server config
    server_config = convert_allocations_to_server_config(allocations)
    with open(os.path.join(log_dir, "server_config.yaml"), "w") as f:
        yaml.dump(server_config, f)
    disagg_server_hostname = server_config['hostname']
    disagg_server_port = server_config['port']

    container_name = "disaggr-test"
    start_server_cmds = []
    # Generate start worker commands with placeholder hostnames
    for server_type in allocations.keys():
        for server_id in allocations[server_type].keys():
            allocation = allocations[server_type][server_id]
            # Use the GPU ids of the allocation's first node for
            # CUDA_VISIBLE_DEVICES.
            cuda_devices = ",".join([
                str(device) for device in list(allocation["nodes"].values())[0]
            ])
            cur_worker_env_var = worker_env_var + f" CUDA_VISIBLE_DEVICES={cuda_devices}"
            cmd = [
                "srun -l",
                f"--nodelist {','.join(allocation['nodes'].keys())}",
                f"-N {len(allocation['nodes'])}",
                f"--ntasks {gen_world_size if server_type == 'GEN' else ctx_world_size}",
                f"--ntasks-per-node {gpus_per_node}",
                f"--container-image {env_config['container_image']}",
                f"--container-name {container_name}",
                f"--container-mounts {env_config['container_mount']}",
                "--mpi=pmix --overlap",
                f"bash {os.path.join(env_config['work_dir'], 'start_worker.sh')}",
                server_type,
                str(server_id),
                env_config['model_path'],
                str(allocation["port"]),
                benchmark_config['mode'],
                f"'{benchmark_config['concurrency_list']}'",
                str(slurm_config['numa_bind']).lower(),
                log_dir,
                str(profiling_config['nsys_on']).lower(),
                (f"'{profiling_config['gen_profile_range']}'"
                 if server_type == "GEN" else
                 f"'{profiling_config['ctx_profile_range']}'"),
                gen_config_path if server_type == "GEN" else ctx_config_path,
                f"'{cur_worker_env_var}'",
                f"&> {log_dir}/3_output_{server_type}_{server_id}.log &",
            ]
            start_server_cmds.append(" ".join(cmd))

    # Generate start server commands
    cmd = [
        "srun -l",
        f"--nodelist {disagg_server_hostname}",
        f"--container-name={container_name}",
        f"--container-image={env_config['container_image']}",
        f"--container-mounts={env_config['container_mount']}",
        "--mpi=pmix --overlap -N 1 -n 1",
        f"bash {env_config['work_dir']}/start_server.sh {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
        f"&> {log_dir}/4_output_server.log &",
    ]
    start_server_cmds.append(" ".join(cmd))

    # Generate wait server command
    cmd = [
        "srun -l",
        f"--container-name={container_name}",
        f"--container-mounts={env_config['container_mount']}",
        "--mpi=pmix --overlap -N 1 -n 1",
        f"bash {env_config['work_dir']}/wait_server.sh {disagg_server_hostname} {disagg_server_port}",
        f"&> {log_dir}/5_wait_server.log",
    ]
    start_server_cmds.append(" ".join(cmd))

    with open(os.path.join(log_dir, "start_server_cmds.sh"), "w") as f:
        f.write("\n".join(start_server_cmds) + "\n")

    # Generate client commands
    client_cmds = []
    client_slurm_prefix = [
        f"srun -l --container-name={container_name}",
        f"--container-mounts={env_config['container_mount']}",
        "--mpi=pmix --overlap -N 1 -n 1",
    ]
    if benchmark_config['use_nv_sa_benchmark']:
        benchmark_cmd = [
            f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh",
            f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
            f"&> {log_dir}/6_bench.log"
        ]
        client_cmds.append(" ".join(client_slurm_prefix + benchmark_cmd))
    else:
        benchmark_cmd = [
            f"bash {env_config['work_dir']}/run_benchmark.sh",
            f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
            f"&> {log_dir}/6_bench.log"
        ]
        client_cmds.append(" ".join(client_slurm_prefix + benchmark_cmd))
    if config['accuracy']['enable_accuracy_test']:
        accuracy_cmd = [
            f"bash {env_config['work_dir']}/accuracy_eval.sh",
            f"'{log_dir}' '{config['accuracy']['model']}' '{config['accuracy']['tasks']}' '{env_config['model_path']}' '{config['accuracy']['model_args_extra']}' '{log_dir}/accuracy_eval' {disagg_server_hostname} {disagg_server_port}",
            f"&> {log_dir}/7_accuracy_eval.log"
        ]
        client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd))
    with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
        f.write("\n".join(client_cmds) + "\n")

    # Prepare sbatch command
    # yapf: disable
    cmd = [
        'sbatch',
        f'--partition={slurm_config["partition"]}',
        f'--account={slurm_config["account"]}',
        f'--time={slurm_config["job_time"]}',
        f'--job-name={slurm_config["job_name"]}',
        f'--nodes={total_nodes}',
        f'--ntasks={total_tasks}',
        f'--ntasks-per-node={hw_config["gpus_per_node"]}',
        *([] if not slurm_config['set_segment'] else [f'--segment={total_nodes}']),
        f'--output={log_dir}/slurm-%j.out',
        f'--error={log_dir}/slurm-%j.err',
        *([arg for arg in slurm_config['extra_args'].split() if arg]),
        slurm_config['script_file'],

        # Benchmark configuration
        '--benchmark-mode', benchmark_config['mode'],

        # Environment and paths
        '--trtllm-repo', env_config['trtllm_repo'],
        '--work-dir', env_config['work_dir'],
        '--full-logdir', log_dir,
        '--container-name', container_name,
        '--container-mount', env_config['container_mount'],
        '--container-image', env_config['container_image'],
        '--build-wheel', str(env_config['build_wheel']).lower(),
        '--cuda-architectures', env_config['cuda_architectures'],
        '--trtllm-wheel-path', env_config['trtllm_wheel_path'],
    ]
    # yapf: enable

    if dry_run:
        print(
            "[WARNING] Dry-run mode: the job will not be submitted. This should be used for testing purposes only."
        )
        print("sbatch command:")
        print(" ".join(cmd))
        return
    else:
        # Submit the job
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error submitting job: {e}", file=sys.stderr)
            sys.exit(1)


def main():
    """Submit a job for each provided YAML config file."""
    args = parse_args()

    # Determine which mode to use
    if args.config:
        # Single config file mode
        config_files = [args.config]
    else:
        # Directory mode - find all YAML files
        yaml_pattern = os.path.join(args.dir, '*.yaml')
        config_files = sorted(glob.glob(yaml_pattern))

        if not config_files:
            print(f"No YAML files found in directory: {args.dir}",
                  file=sys.stderr)
            sys.exit(1)

        print(f"Found {len(config_files)} YAML file(s) in {args.dir}")

    # Process each config file
    for config_file in config_files:
        print(f"Processing: {config_file}")
        try:
            config = load_config(config_file)
            submit_job(config, args.log_dir, args.dry_run)
            print(f"Successfully submitted job for: {config_file}\n")
        except Exception as e:
            traceback.print_exc()
            print(f"Error processing {config_file}: {e}", file=sys.stderr)
            # Continue processing other files even if one fails
            continue


if __name__ == '__main__':
    main()