#!/usr/bin/env python3
"""Submit disaggregated benchmark jobs to Slurm from YAML configuration files."""
import argparse
import glob
import json
import math
import os
import shutil
import subprocess
import sys
from datetime import datetime
from typing import Any, Dict, List

import yaml


def parse_args():
    """Parse command-line arguments.

    Exactly one of ``--config`` (single YAML file) or ``--dir`` (directory of
    YAML files) is required; ``--log-dir`` is optional.
    """
    parser = argparse.ArgumentParser(
        description='Submit disaggregated benchmark job')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-c',
                       '--config',
                       type=str,
                       help='Path to the configuration YAML file')
    group.add_argument('-d',
                       '--dir',
                       type=str,
                       help='Directory containing YAML configuration files')
    parser.add_argument('--log-dir',
                        type=str,
                        default=None,
                        help='Log directory')
    return parser.parse_args()


def load_config(config_path):
    """Load and return the YAML document stored at ``config_path``."""
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def save_worker_config(config, output_path, worker_type):
    """Save worker config to a separate YAML file.

    Args:
        config: Full job configuration; must contain
            ``config['worker_config'][worker_type]``.
        output_path: Destination YAML path. Its parent directory is created
            if it does not exist.
        worker_type: Key under ``worker_config`` (the callers in this file
            pass ``'ctx'`` or ``'gen'``).
    """
    # Get just the worker configuration without the wrapper
    worker_config = config['worker_config'][worker_type]

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        yaml.dump(worker_config, f, default_flow_style=False)


def calculate_nodes(world_size, num_servers, gpus_per_node):
    """Calculate required nodes based on world size and server count."""
    return math.ceil(world_size * num_servers / gpus_per_node)


def allocate_gpus(
    total_nodes: int,
    gpus_per_node: int,
    num_gen_servers: int,
    num_ctx_servers: int,
    gen_world_size: int,
    ctx_world_size: int,
    base_port: int = 8000,
) -> List[Dict[str, Any]]:
    """Assign consecutive GPUs to GEN servers first, then to CTX servers.

    GPUs are handed out from a single running cursor that walks node 0's
    GPUs, then node 1's, and so on, so a server may span several nodes.

    Args:
        total_nodes: Number of nodes available for allocation.
        gpus_per_node: Number of GPUs on each node.
        num_gen_servers: Number of GEN servers to place.
        num_ctx_servers: Number of CTX servers to place.
        gen_world_size: GPUs consumed by each GEN server.
        ctx_world_size: GPUs consumed by each CTX server.
        base_port: Port of server 0; server ``i`` of each type gets
            ``base_port + i`` (numbering restarts for each server type).

    Returns:
        One dict per server with keys ``server_type``, ``server_id``,
        ``port`` and ``nodes``; ``nodes`` maps a hostname placeholder to the
        list of local GPU ids used on that node.

    Raises:
        IndexError: If the servers need more GPUs than
            ``total_nodes * gpus_per_node``.
    """
    allocations = []
    # Real hostnames are unknown at submit time, so every node gets its own
    # placeholder token. The tokens must be distinct: they are used as dict
    # keys below, and identical names would merge all nodes of a multi-node
    # server into a single entry.
    # NOTE(review): presumably these tokens are substituted with the real
    # hostnames once the job is running - confirm the exact token format
    # against the Slurm script that consumes start_worker_cmds.txt.
    hostnames = [f"<node{i}_placeholder>" for i in range(total_nodes)]

    global_gpu_cursor = 0

    def get_gpu_location(gpus_per_node: int):
        """Translate the global cursor into (node index, local GPU index)."""
        node_id = global_gpu_cursor // gpus_per_node
        local_gpu_id = global_gpu_cursor % gpus_per_node
        return node_id, local_gpu_id

    def assign_server(server_allocation: Dict[str, Any], world_size: int,
                      gpus_per_node: int):
        """Give the next ``world_size`` GPUs to ``server_allocation``."""
        nonlocal global_gpu_cursor
        for _ in range(world_size):
            node_id, gpu_id = get_gpu_location(gpus_per_node)
            hostname = hostnames[node_id]
            if hostname not in server_allocation["nodes"]:
                server_allocation["nodes"][hostname] = []
            server_allocation["nodes"][hostname].append(gpu_id)
            global_gpu_cursor += 1

    def assign_servers(
        server_allocations: List[Dict[str, Any]],
        server_type: str,
        num_servers: int,
        world_size: int,
        gpus_per_node: int,
    ):
        """Create and append ``num_servers`` allocations of ``server_type``."""
        for i in range(num_servers):
            server_allocation = {
                "server_type": server_type,
                "server_id": i,
                "port": base_port + i,
                "nodes": {},
            }
            assign_server(server_allocation, world_size, gpus_per_node)
            server_allocations.append(server_allocation)

    assign_servers(allocations, "GEN", num_gen_servers, gen_world_size,
                   gpus_per_node)
    assign_servers(allocations, "CTX", num_ctx_servers, ctx_world_size,
                   gpus_per_node)

    return allocations


def submit_job(config, log_dir):
    """Prepare the log directory and submit one benchmark job via ``sbatch``.

    Fills in defaults for optional configuration sections (mutating
    ``config`` in place), derives node/task counts from the worker parallel
    sizes, writes ``ctx_config.yaml``, ``gen_config.yaml``,
    ``allocations.json`` and ``start_worker_cmds.txt`` into the log
    directory, and finally runs ``sbatch``.

    Args:
        config: Parsed job configuration. Reads the ``slurm``, ``hardware``,
            ``environment``, ``benchmark`` and ``worker_config`` sections and
            the optional ``accuracy`` and ``profiling`` sections.
        log_dir: Log directory to use, or ``None`` to derive one under
            ``environment.work_dir`` from the date and benchmark settings.
            An existing directory at that path is deleted and recreated.

    Raises:
        KeyError: If a required configuration key is missing.
        SystemExit: With status 1 if ``sbatch`` exits with a non-zero status.
    """
    # Extract configurations
    slurm_config = config['slurm']
    slurm_config.setdefault('extra_args', '')
    slurm_config.setdefault('set_segment', True)

    hw_config = config['hardware']
    env_config = config['environment']

    # Set default accuracy configuration for backward compatibility
    if 'accuracy' not in config:
        config['accuracy'] = {
            'enable_accuracy_test':
            False,
            'model':
            'local-completions',
            'tasks':
            'gsm8k',
            'model_args_extra':
            'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
        }

    # Set default environment configuration for backward compatibility
    env_config.setdefault('trtllm_repo', '')
    env_config.setdefault('build_wheel', False)
    env_config.setdefault('cuda_architectures', '')
    env_config.setdefault('trtllm_wheel_path', '')
    env_config.setdefault('worker_env_var', '')
    env_config.setdefault('server_env_var', '')

    profiling_config = config.get('profiling', {})
    profiling_config.setdefault('nsys_on', False)
    profiling_config.setdefault('ctx_profile_range', '10-30')
    profiling_config.setdefault('gen_profile_range', '200-250')

    # Get number of servers from config
    ctx_num = hw_config['num_ctx_servers']
    gen_num = hw_config['num_gen_servers']
    gpus_per_node = hw_config['gpus_per_node']

    # Get mtp_size from gen config's speculative_config
    gen_config = config['worker_config']['gen']
    mtp_size = gen_config.get('speculative_config',
                              {}).get('num_nextn_predict_layers', 0)

    # Calculate nodes based on world sizes
    ctx_tp_size = config['worker_config']['ctx'].get('tensor_parallel_size', 1)
    ctx_cp_size = config['worker_config']['ctx'].get('context_parallel_size',
                                                     1)
    ctx_pp_size = config['worker_config']['ctx'].get('pipeline_parallel_size',
                                                     1)
    ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size
    ctx_nodes = calculate_nodes(ctx_world_size, ctx_num, gpus_per_node)

    gen_tp_size = config['worker_config']['gen'].get('tensor_parallel_size', 1)
    gen_cp_size = config['worker_config']['gen'].get('context_parallel_size',
                                                     1)
    gen_pp_size = config['worker_config']['gen'].get('pipeline_parallel_size',
                                                     1)
    gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
    gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)

    total_nodes = ctx_nodes + gen_nodes
    total_tasks = total_nodes * gpus_per_node

    # Generate log directory path based on configuration
    isl = config['benchmark']['input_length']
    osl = config['benchmark']['output_length']
    gen_batch_size = config['worker_config']['gen']['max_batch_size']
    gen_enable_attention_dp = config['worker_config']['gen'][
        'enable_attention_dp']

    # A YAML scalar such as `concurrency_list: 8` is parsed as an int; both
    # command lists below must contain only strings.
    concurrency_list = str(config['benchmark']['concurrency_list'])

    if log_dir is None:
        # Create base log directory path
        date_prefix = datetime.now().strftime("%Y%m%d")
        log_base = os.path.join(env_config['work_dir'],
                                f"{date_prefix}/{isl}-{osl}")

        # Get eplb num_slots for gen worker; load_balancer may be either an
        # inline mapping or a path to a YAML file holding that mapping.
        load_balancer_config = config['worker_config']['gen'].get(
            'moe_config', {}).get('load_balancer', {})
        if isinstance(load_balancer_config, str):
            with open(load_balancer_config, 'r') as f:
                load_balancer_config = yaml.safe_load(f)
        eplb_num_slots = load_balancer_config.get('num_slots', 0)

        # Determine directory suffix based on attention_dp
        if gen_enable_attention_dp:
            dir_suffix = f"ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
        else:
            dir_suffix = f"ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"

        # Create full log directory path
        log_dir = os.path.join(log_base, dir_suffix)

    # Remove existing directory if it exists
    if os.path.exists(log_dir):
        print(f"[WARNING] Removing existing log directory: {log_dir}")
        shutil.rmtree(log_dir)
    os.makedirs(log_dir)
    print(f"Log will be saved to: {log_dir}")

    # Setup config file paths and save worker configs
    ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
    gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
    save_worker_config(config, ctx_config_path, 'ctx')
    save_worker_config(config, gen_config_path, 'gen')

    # Prepare allocation template
    allocations = allocate_gpus(
        total_nodes=total_nodes,
        gpus_per_node=gpus_per_node,
        num_gen_servers=gen_num,
        num_ctx_servers=ctx_num,
        gen_world_size=gen_world_size,
        ctx_world_size=ctx_world_size,
    )
    with open(os.path.join(log_dir, "allocations.json"), "w") as f:
        json.dump(allocations, f, indent=2)

    # Generate start worker commands with placeholder hostnames
    start_worker_cmds = []
    for allocation in allocations:
        server_type = allocation["server_type"]
        # CUDA_VISIBLE_DEVICES is taken from the first node of the
        # allocation only.
        cuda_devices = ",".join(
            [str(device) for device in list(allocation["nodes"].values())[0]])
        worker_env_var = env_config[
            'worker_env_var'] + f" CUDA_VISIBLE_DEVICES={cuda_devices}"
        cmd = [
            "srun",
            "-l",
            "--nodelist",
            ",".join(allocation["nodes"].keys()),
            "-N",
            str(len(allocation["nodes"])),
            "--ntasks",
            str(gen_world_size)
            if server_type == "GEN" else str(ctx_world_size),
            "--ntasks-per-node",
            str(gpus_per_node),
            "--container-image",
            env_config['container_image'],
            "--container-name",
            "disaggr-test",
            "--container-mounts",
            env_config['container_mount'],
            "--mpi",
            "pmix",
            "--overlap",
            "bash",
            os.path.join(env_config['work_dir'], "start_worker.sh"),
            server_type,
            str(allocation["server_id"]),
            env_config['model_path'],
            str(allocation["port"]),
            config['benchmark']['mode'],
            concurrency_list,
            str(slurm_config['numa_bind']).lower(),
            log_dir,
            str(profiling_config['nsys_on']).lower(),
            profiling_config['gen_profile_range'] if server_type == "GEN" else
            profiling_config['ctx_profile_range'],
            gen_config_path if server_type == "GEN" else ctx_config_path,
            f'"{worker_env_var}"',
            f"&> {log_dir}/3_output_{server_type}_{allocation['server_id']}.log &",
        ]
        start_worker_cmds.append(" ".join(cmd))
    with open(os.path.join(log_dir, "start_worker_cmds.txt"), "w") as f:
        f.write("\n".join(start_worker_cmds) + "\n")

    # Prepare sbatch command
    # yapf: disable
    cmd = [
        'sbatch',
        f'--partition={slurm_config["partition"]}',
        f'--account={slurm_config["account"]}',
        f'--time={slurm_config["job_time"]}',
        f'--job-name={slurm_config["job_name"]}',
        f'--nodes={total_nodes}',
        f'--ntasks={total_tasks}',
        f'--ntasks-per-node={hw_config["gpus_per_node"]}',
        *([] if not slurm_config['set_segment'] else [f'--segment={total_nodes}']),
        f'--output={log_dir}/slurm-%j.out',
        f'--error={log_dir}/slurm-%j.err',
        f'--gpus-per-node={hw_config["gpus_per_node"]}',
        *([arg for arg in slurm_config['extra_args'].split() if arg]),
        slurm_config['script_file'],

        # Worker configuration
        '--num-ctx-servers', str(ctx_num),
        '--num-gen-servers', str(gen_num),
        '--concurrency-list', concurrency_list,

        # Sequence and benchmark parameters
        '--isl', str(config['benchmark']['input_length']),
        '--osl', str(config['benchmark']['output_length']),
        '--multi-round', str(config['benchmark']['multi_round']),
        '--benchmark-ratio', str(config['benchmark']['benchmark_ratio']),
        '--streaming', str(config['benchmark']['streaming']).lower(),
        '--use-nv-sa-benchmark', str(config['benchmark']['use_nv_sa_benchmark']).lower(),
        '--benchmark-mode', config['benchmark']['mode'],

        # Environment and paths
        '--dataset-file', config['benchmark']['dataset_file'],
        '--model-path', env_config['model_path'],
        '--trtllm-repo', env_config['trtllm_repo'],
        '--work-dir', env_config['work_dir'],
        '--full-logdir', log_dir,
        '--container-mount', env_config['container_mount'],
        '--container-image', env_config['container_image'],
        '--build-wheel', str(env_config['build_wheel']).lower(),
        '--cuda-architectures', env_config['cuda_architectures'],
        '--trtllm-wheel-path', env_config['trtllm_wheel_path'],

        # Accuracy evaluation
        '--enable-accuracy-test', str(config['accuracy']['enable_accuracy_test']).lower(),
        '--accuracy-model', config['accuracy']['model'],
        '--accuracy-tasks', config['accuracy']['tasks'],
        '--model-args-extra', config['accuracy']['model_args_extra'],

        # Server environment variables
        '--server-env-var', env_config['server_env_var']
    ]
    # yapf: enable

    # Submit the job
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error submitting job: {e}", file=sys.stderr)
        sys.exit(1)


def main():
    """Submit one job per configuration file selected on the command line."""
    args = parse_args()

    # Determine which mode to use
    if args.config:
        # Single config file mode
        config_files = [args.config]
    else:
        # Directory mode - find all YAML files
        yaml_pattern = os.path.join(args.dir, '*.yaml')
        config_files = sorted(glob.glob(yaml_pattern))

        if not config_files:
            print(f"No YAML files found in directory: {args.dir}",
                  file=sys.stderr)
            sys.exit(1)

        print(f"Found {len(config_files)} YAML file(s) in {args.dir}")

    # Process each config file
    for config_file in config_files:
        print(f"Processing: {config_file}")
        try:
            config = load_config(config_file)
            submit_job(config, args.log_dir)
            print(f"Successfully submitted job for: {config_file}\n")
        except Exception as e:
            # NOTE: an sbatch failure makes submit_job call sys.exit(1);
            # SystemExit is not an Exception subclass, so it is not caught
            # here and terminates the whole run.
            print(f"Error processing {config_file}: {e}", file=sys.stderr)
            # Continue processing other files even if one fails
            continue


if __name__ == '__main__':
    main()