#!/usr/bin/env python3
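"""Submit disaggregated benchmark jobs to Slurm from one or more YAML configs.

Example invocations (a sketch; the config paths below are placeholders):

    # Submit a single benchmark configuration
    python3 submit.py -c config.yaml

    # Submit every *.yaml configuration found in a directory
    python3 submit.py -d configs/
"""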
import argparse
import glob
import os
import subprocess
import sys

import yaml


def parse_args():
    parser = argparse.ArgumentParser(
        description='Submit disaggregated benchmark job')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-c',
                       '--config',
                       type=str,
                       help='Path to the configuration YAML file')
    group.add_argument('-d',
                       '--dir',
                       type=str,
                       help='Directory containing YAML configuration files')
    return parser.parse_args()


def load_config(config_path):
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)
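

# Top-level layout expected of the configuration YAML, inferred from the keys
# this script reads below. This is an illustrative sketch only; every value is
# a placeholder, not a recommended setting:
#
#   slurm:
#     partition: <partition>
#     account: <account>
#     job_time: "02:00:00"
#     job_name: disagg-benchmark
#     script_file: <path to the sbatch script>
#     numa_bind: true
#   hardware:
#     num_ctx_servers: 1
#     num_gen_servers: 1
#     gpus_per_node: 8
#   environment:
#     work_dir: <path>
#     dataset_file: <path>
#     model_path: <path>
#     trtllm_repo: <path>
#     container_mount: <src>:<dst>
#     container_image: <image>
#     build_wheel: false
#   sequence:
#     input_length: 1024
#     output_length: 1024
#   benchmark:
#     concurrency_list: "1 2 4"
#     multi_round: 1
#     benchmark_ratio: 1.0
#     streaming: true
#     use_nv_sa_benchmark: false
#     mode: <mode>
#   profiling:
#     nsys_on: false
#   worker_config:
#     ctx: {...}  # full per-worker TRT-LLM config, dumped to ctx_config.yaml
#     gen: {...}  # full per-worker TRT-LLM config, dumped to gen_config.yaml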


def save_worker_config(config, output_path, worker_type):
    """Save worker config to a separate YAML file."""
    # Get just the worker configuration without the wrapper
    worker_config = config['worker_config'][worker_type]
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        yaml.dump(worker_config, f, default_flow_style=False)
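

# For example, with the layout sketched above, save_worker_config(config,
# ctx_config_path, 'ctx') writes only the contents of
# config['worker_config']['ctx'] (tensor_parallel_size, pipeline_parallel_size,
# and the rest of that worker's settings) to ctx_config.yaml, without the
# 'worker_config' wrapper.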


def calculate_nodes(world_size, num_servers, gpus_per_node):
    """Calculate required nodes based on world size and server count."""
    return (world_size + gpus_per_node - 1) // gpus_per_node * num_servers
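

# Worked example (illustrative numbers only): a gen worker with
# tensor_parallel_size=16 and pipeline_parallel_size=1 has world_size=16; on
# 8-GPU nodes that is ceil(16 / 8) = 2 nodes per server, so two gen servers
# need calculate_nodes(16, 2, 8) == 4 nodes.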


def submit_job(config):
    # Extract configurations
    slurm_config = config['slurm']
    hw_config = config['hardware']
    env_config = config['environment']

    # Get number of servers from config
    ctx_num = hw_config['num_ctx_servers']
    gen_num = hw_config['num_gen_servers']

    # Get mtp_size from gen config's speculative_config
    gen_config = config['worker_config']['gen']
    mtp_size = gen_config.get('speculative_config',
                              {}).get('num_nextn_predict_layers', 0)

    # Calculate nodes based on world sizes
    ctx_tp_size = config['worker_config']['ctx']['tensor_parallel_size']
    ctx_pp_size = config['worker_config']['ctx']['pipeline_parallel_size']
    ctx_world_size = ctx_tp_size * ctx_pp_size
    ctx_nodes = calculate_nodes(ctx_world_size, ctx_num,
                                hw_config['gpus_per_node'])

    gen_tp_size = config['worker_config']['gen']['tensor_parallel_size']
    gen_pp_size = config['worker_config']['gen']['pipeline_parallel_size']
    gen_world_size = gen_tp_size * gen_pp_size
    gen_nodes = calculate_nodes(gen_world_size, gen_num,
                                hw_config['gpus_per_node'])

    total_nodes = ctx_nodes + gen_nodes
    total_tasks = total_nodes * hw_config['gpus_per_node']

    # Generate log directory path based on configuration
    isl = config['sequence']['input_length']
    osl = config['sequence']['output_length']
    gen_batch_size = config['worker_config']['gen']['max_batch_size']
    gen_enable_attention_dp = config['worker_config']['gen'][
        'enable_attention_dp']

    # Create base log directory path
    log_base = os.path.join(env_config['work_dir'], f"{isl}-{osl}")

    # Determine directory suffix based on attention_dp
    if gen_enable_attention_dp:
        dir_suffix = f"ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{config['worker_config']['gen']['moe_config']['load_balancer']['num_slots']}_mtp{mtp_size}"
    else:
        dir_suffix = f"ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{config['worker_config']['gen']['moe_config']['load_balancer']['num_slots']}_mtp{mtp_size}"

    # Create full log directory path
    log_dir = os.path.join(log_base, dir_suffix)
    os.makedirs(log_dir, exist_ok=True)

    # Set up config file paths and save worker configs
    ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
    gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
    save_worker_config(config, ctx_config_path, 'ctx')
    save_worker_config(config, gen_config_path, 'gen')

    # Prepare sbatch command
    cmd = [
        'sbatch',
        f'--partition={slurm_config["partition"]}',
        f'--gres=gpu:{hw_config["gpus_per_node"]}',
        f'--account={slurm_config["account"]}',
        f'--time={slurm_config["job_time"]}',
        f'--job-name={slurm_config["job_name"]}',
        f'--nodes={total_nodes}',
        f'--ntasks={total_tasks}',
        f'--ntasks-per-node={hw_config["gpus_per_node"]}',
        f'--segment={total_nodes}',
        slurm_config['script_file'],
        # Hardware configuration
        str(hw_config['gpus_per_node']),
        str(slurm_config['numa_bind']).lower(),
        str(ctx_nodes),  # Number of nodes needed for ctx workers
        str(gen_nodes),  # Number of nodes needed for gen workers
        str(ctx_world_size),  # World size for ctx workers
        str(gen_world_size),  # World size for gen workers
        # Worker configuration
        str(ctx_num),
        ctx_config_path,
        str(gen_num),
        gen_config_path,
        config['benchmark']['concurrency_list'],
        # Sequence and benchmark parameters
        str(config['sequence']['input_length']),
        str(config['sequence']['output_length']),
        str(config['benchmark']['multi_round']),
        str(config['benchmark']['benchmark_ratio']),
        str(config['benchmark']['streaming']).lower(),
        str(config['benchmark']['use_nv_sa_benchmark']).lower(),
        config['benchmark']['mode'],
        str(config['worker_config']['gen']['cache_transceiver_config']
            ['max_tokens_in_buffer']),
        # Environment and paths
        env_config['dataset_file'],
        env_config['model_path'],
        env_config['trtllm_repo'],
        env_config['work_dir'],
        log_dir,  # Pass the generated log directory
        env_config['container_mount'],
        env_config['container_image'],
        str(env_config['build_wheel']).lower(),
        # Profiling
        str(config['profiling']['nsys_on']).lower()
    ]

    # Submit the job
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error submitting job: {e}", file=sys.stderr)
        sys.exit(1)
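

# For illustration with placeholder values: isl=1024, osl=1024, one ctx server,
# one gen server with tensor_parallel_size=8, max_batch_size=256,
# moe_config.load_balancer.num_slots=288, mtp_size=3, and enable_attention_dp
# set, submit_job() writes the worker configs and logs under
# <work_dir>/1024-1024/ctx1_gen1_dep8_batch256_eplb288_mtp3/.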


def main():
    args = parse_args()

    # Determine which mode to use
    if args.config:
        # Single config file mode
        config_files = [args.config]
    else:
        # Directory mode - find all YAML files
        yaml_pattern = os.path.join(args.dir, '*.yaml')
        config_files = sorted(glob.glob(yaml_pattern))
        if not config_files:
            print(f"No YAML files found in directory: {args.dir}",
                  file=sys.stderr)
            sys.exit(1)
        print(f"Found {len(config_files)} YAML file(s) in {args.dir}")

    # Process each config file
    for config_file in config_files:
        print(f"\nProcessing: {config_file}")
        try:
            config = load_config(config_file)
            submit_job(config)
            print(f"Successfully submitted job for: {config_file}")
        except Exception as e:
            print(f"Error processing {config_file}: {e}", file=sys.stderr)
            # Continue processing other files even if one fails
            continue


if __name__ == '__main__':
    main()