TensorRT-LLMs/examples/disaggregated/slurm/benchmark/submit.py
#!/usr/bin/env python3
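"""Submit disaggregated serving benchmark jobs to Slurm.

For each YAML configuration file (a single file via -c, or every *.yaml in a
directory via -d) the script computes the node/GPU allocation for context and
generation workers, writes the worker configs, the disaggregated server config
and the launch/client command scripts into a log directory, and submits the
job with sbatch (unless --dry-run is given).

Illustrative usage (file and directory names are placeholders):

    python3 submit.py -c my_benchmark.yaml
    python3 submit.py -d configs/ --dry-run

The configuration is expected to provide at least the top-level sections read
below: 'slurm', 'hardware', 'environment', 'worker_config' (with 'ctx' and
'gen' entries) and 'benchmark'; 'accuracy' and 'profiling' are optional and
receive defaults.
"""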
import argparse
import glob
import json
import math
import os
import shutil
import subprocess
import sys
import traceback
from datetime import datetime
from typing import Any, Dict, List

import yaml


def parse_args():
    parser = argparse.ArgumentParser(
        description='Submit disaggregated benchmark job')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-c',
                       '--config',
                       type=str,
                       help='Path to the configuration YAML file')
    group.add_argument('-d',
                       '--dir',
                       type=str,
                       help='Directory containing YAML configuration files')
    parser.add_argument('--log-dir',
                        type=str,
                        default=None,
                        help='Log directory')
    parser.add_argument('--dry-run',
                        action='store_true',
                        help='Prepare configs and commands without submitting '
                        'the job (for testing only)')
    return parser.parse_args()


def load_config(config_path):
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def save_worker_config(worker_config, output_path):
    """Save worker config to a separate YAML file."""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        yaml.dump(worker_config, f, default_flow_style=False)


def calculate_nodes(world_size, num_servers, gpus_per_node):
    """Calculate required nodes based on world size and server count."""
    return math.ceil(world_size * num_servers / gpus_per_node)
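
# Worked example (illustrative): two generation servers, each with a world
# size of 8 (e.g. TP8), on 4-GPU nodes require ceil(8 * 2 / 4) = 4 nodes.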


def allocate_gpus(
    total_nodes: int,
    gpus_per_node: int,
    num_gen_servers: int,
    num_ctx_servers: int,
    gen_world_size: int,
    ctx_world_size: int,
    base_port: int = 8000,
) -> Dict[str, Any]:
    """Assign GPUs to GEN and CTX server instances, filling nodes in order.

    Hostnames are emitted as placeholders ("<node{i}_placeholder>") that are
    expected to be replaced with real node names at launch time.
    """
    allocations = {}
    hostnames = [f"<node{i}_placeholder>" for i in range(total_nodes)]
    global_gpu_cursor = 0

    def get_gpu_location(gpus_per_node: int):
        node_id = global_gpu_cursor // gpus_per_node
        local_gpu_id = global_gpu_cursor % gpus_per_node
        return node_id, local_gpu_id

    def assign_server(server_allocation: Dict[str, Any], world_size: int,
                      gpus_per_node: int):
        nonlocal global_gpu_cursor
        for _ in range(world_size):
            node_id, gpu_id = get_gpu_location(gpus_per_node)
            hostname = hostnames[node_id]
            if hostname not in server_allocation["nodes"]:
                server_allocation["nodes"][hostname] = []
            server_allocation["nodes"][hostname].append(gpu_id)
            global_gpu_cursor += 1

    def assign_servers(
        server_allocations: Dict[str, Any],
        server_type: str,
        num_servers: int,
        world_size: int,
        gpus_per_node: int,
    ):
        if server_type not in server_allocations:
            server_allocations[server_type] = {}
        for i in range(num_servers):
            server_allocation = {
                "port": base_port + i,
                "nodes": {},
            }
            assign_server(server_allocation, world_size, gpus_per_node)
            server_allocations[server_type][i] = server_allocation

    assign_servers(allocations, "GEN", num_gen_servers, gen_world_size,
                   gpus_per_node)
    assign_servers(allocations, "CTX", num_ctx_servers, ctx_world_size,
                   gpus_per_node)
    return allocations
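
# For reference, allocate_gpus() returns a nested dict shaped roughly like the
# sketch below (ports and GPU ids depend on the configuration); the same
# structure is written to allocations.json by submit_job():
#   {
#       "GEN": {0: {"port": 8000, "nodes": {"<node0_placeholder>": [0, 1, 2, 3]}},
#               1: {"port": 8001, "nodes": {"<node1_placeholder>": [0, 1, 2, 3]}}},
#       "CTX": {0: {"port": 8000, "nodes": {"<node2_placeholder>": [0, 1, 2, 3]}}},
#   }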


def convert_allocations_to_server_config(allocations, server_port=8333):
    """Convert GPU allocations into the disaggregated server config dict."""
    generation_servers = {}
    context_servers = {}
    server_hostname = None
    for server_type in allocations.keys():
        num_servers = len(allocations[server_type])
        urls = []
        for server_id in allocations[server_type].keys():
            instance = allocations[server_type][server_id]
            urls.append(
                f"{list(instance['nodes'].keys())[0]}:{instance['port']}")
        if server_type == "GEN":
            generation_servers = {'num_instances': num_servers, 'urls': urls}
            server_hostname = urls[0].split(':')[0]
            if allocations[server_type][server_id]['port'] == server_port:
                server_port += 1  # Avoid port conflict
        elif server_type == "CTX":
            context_servers = {'num_instances': num_servers, 'urls': urls}
    server_config = {
        'backend': 'pytorch',
        'hostname': server_hostname,
        'port': server_port,
        'context_servers': context_servers,
        'generation_servers': generation_servers
    }
    return server_config
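
# For reference, the resulting server_config (written to server_config.yaml)
# looks roughly like this; hostnames and ports are illustrative:
#   backend: pytorch
#   hostname: <node0_placeholder>
#   port: 8333
#   context_servers:
#     num_instances: 1
#     urls: ['<node2_placeholder>:8000']
#   generation_servers:
#     num_instances: 2
#     urls: ['<node0_placeholder>:8000', '<node1_placeholder>:8001']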


def submit_job(config, log_dir, dry_run):
    """Prepare configs and launch scripts, then submit the Slurm job."""
    # Extract configurations
    slurm_config = config['slurm']
    slurm_config.setdefault('extra_args', '')
    slurm_config.setdefault('set_segment', True)
    hw_config = config['hardware']
    env_config = config['environment']
    worker_config = config['worker_config']
    benchmark_config = config['benchmark']

    # Set default accuracy configuration for backward compatibility
    if 'accuracy' not in config:
        config['accuracy'] = {
            'enable_accuracy_test': False,
            'model': 'local-completions',
            'tasks': 'gsm8k',
            'model_args_extra':
            'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
        }

    # Set default environment configuration for backward compatibility
    env_config.setdefault('trtllm_repo', '')
    env_config.setdefault('build_wheel', False)
    env_config.setdefault('cuda_architectures', '')
    env_config.setdefault('trtllm_wheel_path', '')
    env_config.setdefault('worker_env_var', '')
    env_config.setdefault('server_env_var', '')
    worker_env_var = env_config.get('worker_env_var')
    server_env_var = env_config.get('server_env_var')
    if benchmark_config['mode'] == "gen_only_no_context":
        worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
        server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"

    profiling_config = config.get('profiling', {})
    profiling_config.setdefault('nsys_on', False)
    profiling_config.setdefault('ctx_profile_range', '10-30')
    profiling_config.setdefault('gen_profile_range', '200-250')
    # Get number of servers from config
    ctx_num = hw_config['num_ctx_servers']
    gen_num = hw_config['num_gen_servers']
    gpus_per_node = hw_config['gpus_per_node']

    # Calculate nodes based on world sizes
    ctx_tp_size = worker_config['ctx'].get('tensor_parallel_size', 1)
    ctx_cp_size = worker_config['ctx'].get('context_parallel_size', 1)
    ctx_pp_size = worker_config['ctx'].get('pipeline_parallel_size', 1)
    ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size
    ctx_nodes = calculate_nodes(ctx_world_size, ctx_num, gpus_per_node)

    gen_tp_size = worker_config['gen'].get('tensor_parallel_size', 1)
    gen_cp_size = worker_config['gen'].get('context_parallel_size', 1)
    gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1)
    gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
    gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)

    total_nodes = ctx_nodes + gen_nodes
    total_tasks = total_nodes * gpus_per_node
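
    # Example (illustrative): one ctx server with TP4 on 4-GPU nodes needs
    # 1 node and one gen server with TP8 needs 2 nodes, so total_nodes = 3
    # and total_tasks = 12.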
    # Generate log directory path based on configuration
    isl = benchmark_config['input_length']
    osl = benchmark_config['output_length']
    gen_batch_size = worker_config['gen']['max_batch_size']
    gen_enable_attention_dp = worker_config['gen']['enable_attention_dp']
    if log_dir is None:
        # Create base log directory path
        date_prefix = datetime.now().strftime("%Y%m%d")
        log_base = os.path.join(env_config['work_dir'],
                                f"{date_prefix}/{isl}-{osl}")

        # Get eplb num_slots for gen worker
        load_balancer_config = worker_config['gen'].get('moe_config', {}).get(
            'load_balancer', {})
        if isinstance(load_balancer_config, str):
            with open(load_balancer_config, 'r') as f:
                load_balancer_config = yaml.safe_load(f)
        eplb_num_slots = load_balancer_config.get('num_slots', 0)

        # Get mtp_size from gen config's speculative_config
        mtp_size = worker_config['gen'].get('speculative_config', {}).get(
            'num_nextn_predict_layers', 0)

        # Determine directory suffix based on attention_dp
        if gen_enable_attention_dp:
            dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
        else:
            dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"

        # Create full log directory path
        log_dir = os.path.join(log_base, dir_suffix)

    # Remove existing directory if it exists
    if os.path.exists(log_dir):
        print(f"[WARNING] Removing existing log directory: {log_dir}")
        shutil.rmtree(log_dir)
    os.makedirs(log_dir)
    print(f"Log will be saved to: {log_dir}")

    # Setup config file paths and save worker configs
    ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
    gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
    save_worker_config(worker_config['ctx'], ctx_config_path)
    save_worker_config(worker_config['gen'], gen_config_path)
    # Prepare allocation template
    allocations = allocate_gpus(
        total_nodes=total_nodes,
        gpus_per_node=gpus_per_node,
        num_gen_servers=gen_num,
        num_ctx_servers=ctx_num,
        gen_world_size=gen_world_size,
        ctx_world_size=ctx_world_size,
    )
    with open(os.path.join(log_dir, "allocations.json"), "w") as f:
        json.dump(allocations, f, indent=2)

    # Generate disagg server config
    server_config = convert_allocations_to_server_config(allocations)
    with open(os.path.join(log_dir, "server_config.yaml"), "w") as f:
        yaml.dump(server_config, f)
    disagg_server_hostname = server_config['hostname']
    disagg_server_port = server_config['port']
    container_name = "disaggr-test"
    start_server_cmds = []

    # Generate start worker commands with placeholder hostnames
    for server_type in allocations.keys():
        for server_id in allocations[server_type].keys():
            allocation = allocations[server_type][server_id]
            cuda_devices = ",".join([
                str(device) for device in list(allocation["nodes"].values())[0]
            ])
            cur_worker_env_var = worker_env_var + f" CUDA_VISIBLE_DEVICES={cuda_devices}"
            cmd = [
                "srun -l",
                f"--nodelist {','.join(allocation['nodes'].keys())}",
                f"-N {len(allocation['nodes'])}",
                f"--ntasks {gen_world_size if server_type == 'GEN' else ctx_world_size}",
                f"--ntasks-per-node {gpus_per_node}",
                f"--container-image {env_config['container_image']}",
                f"--container-name {container_name}",
                f"--container-mounts {env_config['container_mount']}",
                "--mpi=pmix --overlap",
                f"bash {os.path.join(env_config['work_dir'], 'start_worker.sh')}",
                server_type,
                str(server_id),
                env_config['model_path'],
                str(allocation["port"]),
                benchmark_config['mode'],
                f"'{benchmark_config['concurrency_list']}'",
                str(slurm_config['numa_bind']).lower(),
                log_dir,
                str(profiling_config['nsys_on']).lower(),
                f"'{profiling_config['gen_profile_range']}'" if server_type
                == "GEN" else f"'{profiling_config['ctx_profile_range']}'",
                gen_config_path if server_type == "GEN" else ctx_config_path,
                f"'{cur_worker_env_var}'",
                f"&> {log_dir}/3_output_{server_type}_{server_id}.log &",
            ]
            start_server_cmds.append(" ".join(cmd))
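
    # Each entry appended above is a single srun line of roughly this shape
    # (illustrative values, shown wrapped here for readability):
    #   srun -l --nodelist <node0_placeholder> -N 1 --ntasks 4 --ntasks-per-node 4 \
    #       --container-image <image> --container-name disaggr-test \
    #       --container-mounts <mounts> --mpi=pmix --overlap \
    #       bash <work_dir>/start_worker.sh GEN 0 <model_path> 8000 <mode> \
    #       ... &> <log_dir>/3_output_GEN_0.log &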

    # Generate start server commands
    cmd = [
        "srun -l",
        f"--nodelist {disagg_server_hostname}",
        f"--container-name={container_name}",
        f"--container-image={env_config['container_image']}",
        f"--container-mounts={env_config['container_mount']}",
        "--mpi=pmix --overlap -N 1 -n 1",
        f"bash {env_config['work_dir']}/start_server.sh {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
        f"&> {log_dir}/4_output_server.log &",
    ]
    start_server_cmds.append(" ".join(cmd))

    # Generate wait server command
    cmd = [
        "srun -l",
        f"--container-name={container_name}",
        f"--container-mounts={env_config['container_mount']}",
        "--mpi=pmix --overlap -N 1 -n 1",
        f"bash {env_config['work_dir']}/wait_server.sh {disagg_server_hostname} {disagg_server_port}",
        f"&> {log_dir}/5_wait_server.log",
    ]
    start_server_cmds.append(" ".join(cmd))

    with open(os.path.join(log_dir, "start_server_cmds.sh"), "w") as f:
        f.write("\n".join(start_server_cmds) + "\n")
    # Generate client commands
    client_cmds = []
    client_slurm_prefix = [
        f"srun -l --container-name={container_name}",
        f"--container-mounts={env_config['container_mount']}",
        "--mpi=pmix --overlap -N 1 -n 1",
    ]
    if benchmark_config['use_nv_sa_benchmark']:
        benchmark_cmd = [
            f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh",
            f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
            f"&> {log_dir}/6_bench.log"
        ]
        client_cmds.append(" ".join(client_slurm_prefix + benchmark_cmd))
    else:
        benchmark_cmd = [
            f"bash {env_config['work_dir']}/run_benchmark.sh",
            f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
            f"&> {log_dir}/6_bench.log"
        ]
        client_cmds.append(" ".join(client_slurm_prefix + benchmark_cmd))
    if config['accuracy']['enable_accuracy_test']:
        accuracy_cmd = [
            f"bash {env_config['work_dir']}/accuracy_eval.sh",
            f"'{log_dir}' '{config['accuracy']['model']}' '{config['accuracy']['tasks']}' '{env_config['model_path']}' '{config['accuracy']['model_args_extra']}' '{log_dir}/accuracy_eval' {disagg_server_hostname} {disagg_server_port}",
            f"&> {log_dir}/7_accuracy_eval.log"
        ]
        client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd))
    with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
        f.write("\n".join(client_cmds) + "\n")
    # Prepare sbatch command
    # yapf: disable
    cmd = [
        'sbatch',
        f'--partition={slurm_config["partition"]}',
        f'--account={slurm_config["account"]}',
        f'--time={slurm_config["job_time"]}',
        f'--job-name={slurm_config["job_name"]}',
        f'--nodes={total_nodes}',
        f'--ntasks={total_tasks}',
        f'--ntasks-per-node={hw_config["gpus_per_node"]}',
        *([] if not slurm_config['set_segment'] else [f'--segment={total_nodes}']),
        f'--output={log_dir}/slurm-%j.out',
        f'--error={log_dir}/slurm-%j.err',
        *([arg for arg in slurm_config['extra_args'].split() if arg]),
        slurm_config['script_file'],
        # Benchmark Configuration
        '--benchmark-mode', benchmark_config['mode'],
        # Environment and paths
        '--trtllm-repo', env_config['trtllm_repo'],
        '--work-dir', env_config['work_dir'],
        '--full-logdir', log_dir,
        '--container-name', container_name,
        '--container-mount', env_config['container_mount'],
        '--container-image', env_config['container_image'],
        '--build-wheel', str(env_config['build_wheel']).lower(),
        '--cuda-architectures', env_config['cuda_architectures'],
        '--trtllm-wheel-path', env_config['trtllm_wheel_path'],
    ]
    # yapf: enable

    if dry_run:
        print("[WARNING] Dry run mode, the job will not be submitted. "
              "Use this for testing only.")
        print("sbatch command:")
        print(" ".join(cmd))
        return
    else:
        # Submit the job
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error submitting job: {e}", file=sys.stderr)
            sys.exit(1)


def main():
    args = parse_args()

    # Determine which mode to use
    if args.config:
        # Single config file mode
        config_files = [args.config]
    else:
        # Directory mode - find all YAML files
        yaml_pattern = os.path.join(args.dir, '*.yaml')
        config_files = sorted(glob.glob(yaml_pattern))
        if not config_files:
            print(f"No YAML files found in directory: {args.dir}",
                  file=sys.stderr)
            sys.exit(1)
        print(f"Found {len(config_files)} YAML file(s) in {args.dir}")

    # Process each config file
    for config_file in config_files:
        print(f"Processing: {config_file}")
        try:
            config = load_config(config_file)
            submit_job(config, args.log_dir, args.dry_run)
            print(f"Successfully submitted job for: {config_file}\n")
        except Exception as e:
            traceback.print_exc()
            print(f"Error processing {config_file}: {e}", file=sys.stderr)
            # Continue processing other files even if one fails
            continue


if __name__ == '__main__':
    main()