diff --git a/examples/disaggregated/slurm/benchmark/start_server.sh b/examples/disaggregated/slurm/benchmark/start_server.sh
index 44d4b8e5c3..ff8e5aa902 100644
--- a/examples/disaggregated/slurm/benchmark/start_server.sh
+++ b/examples/disaggregated/slurm/benchmark/start_server.sh
@@ -4,12 +4,5 @@ set -e
 set -x
 
 config_file=$1
-server_env_var=$2
-
-# Export server environment variables from config
-for env_var in ${server_env_var}; do
-    export "${env_var}"
-    echo "Exported: ${env_var}"
-done
 
 trtllm-serve disaggregated -c ${config_file} -t 7200 -r 7200
diff --git a/examples/disaggregated/slurm/benchmark/start_worker.sh b/examples/disaggregated/slurm/benchmark/start_worker.sh
index d96201aa7e..6810e5241a 100644
--- a/examples/disaggregated/slurm/benchmark/start_worker.sh
+++ b/examples/disaggregated/slurm/benchmark/start_worker.sh
@@ -7,23 +7,16 @@ role=${1}
 instance_id=${2}
 model_path=${3}
 port=${4}
-benchmark_mode=${5}
-concurrency=${6}
-numa_bind=${7}
-log_dir=${8}
-enable_nsys=${9}
-profile_range=${10}
-config_file=${11}
-worker_env_var=${12}
+numa_bind=${5}
+log_dir=${6}
+enable_nsys=${7}
+config_file=${8}
 
+# Clear UCX_TLS for specific clusters
 unset UCX_TLS
 
-echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
-# Export worker environment variables from config
-for env_var in ${worker_env_var}; do
-    export "${env_var}"
-    echo "Exported: ${env_var}"
-done
+echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
+echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
 
 if [ "${numa_bind}" = "true" ]; then
     numa_bind_cmd="numactl -m 0,1"
@@ -33,10 +26,6 @@ else
     echo "Not binding memory. If on GB200/GB300 NVL72, use \"numactl -m 0,1\" to only allocate memory from nodes."
 fi
 
-if [ "${benchmark_mode}" = "gen_only" ]; then
-    export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
-fi
-
 echo "config_file: ${config_file}"
 
 nsys_prefix=""
@@ -44,11 +33,7 @@ if [ "${enable_nsys}" != "true" ]; then
     echo "nsys is not enabled, start normal flow"
 else
    nsys_file=${log_dir}/nsys_worker_proc_${role}_${instance_id}_${SLURM_PROCID}
-   export TLLM_PROFILE_RECORD_GC=1
-   export TLLM_NVTX_DEBUG=1
-   export NSYS_MPI_STORE_TEAMS_PER_RANK=1
-   export TLLM_PROFILE_START_STOP=${profile_range}
-   echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${profile_range}"
+   echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${TLLM_PROFILE_START_STOP}"
    nsys_prefix="nsys profile -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
 fi
 
diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py
index a6fa8cc42c..330dff0943 100644
--- a/examples/disaggregated/slurm/benchmark/submit.py
+++ b/examples/disaggregated/slurm/benchmark/submit.py
@@ -113,6 +113,7 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
     generation_servers = {}
     context_servers = {}
     server_hostname = None
+
     for server_type in allocations.keys():
         num_servers = len(allocations[server_type])
         urls = []
@@ -120,13 +121,16 @@
             instance = allocations[server_type][server_id]
             urls.append(
                 f"{list(instance['nodes'].keys())[0]}:{instance['port']}")
+
+        server_config_entry = {'num_instances': num_servers, 'urls': urls}
+
         if server_type == "GEN":
-            generation_servers = {'num_instances': num_servers, 'urls': urls}
+            generation_servers = server_config_entry
             server_hostname = urls[0].split(':')[0]
             if allocations[server_type][server_id]['port'] == server_port:
                 server_port += 1  # Avoid port conflict
         elif server_type == "CTX":
-            context_servers = {'num_instances': num_servers, 'urls': urls}
+            context_servers = server_config_entry
 
     server_config = {
         'backend': 'pytorch',
@@ -168,6 +172,121 @@ def replace_env_in_file(log_dir, file_path, env_var):
     return tmp_dir
 
 
+def build_worker_environment(worker_config, env_config, role, benchmark_mode,
+                             nsys_on, profile_range, concurrency, gpu_ids):
+    """Build complete environment dictionary for worker processes.
+
+    Args:
+        worker_config: Worker configuration dict
+        env_config: Environment configuration dict
+        role: Server role ("CTX" or "GEN")
+        benchmark_mode: Benchmark mode string
+        nsys_on: Whether nsys profiling is enabled
+        profile_range: Profile range string (e.g., "10-30")
+        concurrency: Concurrency level
+        gpu_ids: List of GPU IDs assigned to this worker
+
+    Returns:
+        Dictionary of environment variables
+
+    Note:
+        CUDA_VISIBLE_DEVICES is set here from gpu_ids and forwarded to the
+        worker via srun --export; it is no longer passed as an argument to
+        start_worker.sh.
+    """
+    env = {}
+
+    # 1. Use gpu_ids to set CUDA_VISIBLE_DEVICES
+    cuda_devices = ','.join(map(str, gpu_ids))
+    env["CUDA_VISIBLE_DEVICES"] = cuda_devices
+
+    # 2. Parse user-defined worker env vars from config
+    worker_env_var = env_config.get('worker_env_var', '')
+    for var_string in worker_env_var.split():
+        if '=' in var_string:
+            key, val = var_string.split('=', 1)
+            env[key] = val
+
+    # 3. Add role-specific env vars (CTX or GEN)
+    role_env_vars = {
+        "CTX": env_config.get('ctx_worker_env_var', ''),
+        "GEN": env_config.get('gen_worker_env_var', '')
+    }
+    role_specific_env_var = role_env_vars.get(role, '')
+    for var_string in role_specific_env_var.split():
+        if '=' in var_string:
+            key, val = var_string.split('=', 1)
+            env[key] = val
+
+    # 4. Add mode-based env vars
+    if benchmark_mode == "gen_only_no_context":
+        env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
+    if benchmark_mode == "gen_only":
+        env["TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"] = "1"
+        if role == "GEN":
+            env["TLLM_BENCHMARK_REQ_QUEUES_SIZE"] = str(concurrency)
+
+    # 5. Add profiling env vars (conditional)
+    if nsys_on:
+        env["TLLM_PROFILE_RECORD_GC"] = "1"
+        env["TLLM_NVTX_DEBUG"] = "1"
+        env["NSYS_MPI_STORE_TEAMS_PER_RANK"] = "1"
+        env["TLLM_PROFILE_START_STOP"] = profile_range
+
+    return env
+
+
+def build_server_environment(env_config, benchmark_mode):
+    """Build complete environment dictionary for server process.
+
+    Args:
+        env_config: Environment configuration dict
+        benchmark_mode: Benchmark mode string
+
+    Returns:
+        Dictionary of environment variables
+    """
+    env = {}
+
+    # Parse user-defined server env vars
+    server_env_var = env_config.get('server_env_var', '')
+    for var_string in server_env_var.split():
+        if '=' in var_string:
+            key, val = var_string.split('=', 1)
+            env[key] = val
+
+    # Add mode-based env vars
+    if benchmark_mode == "gen_only_no_context":
+        env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
+
+    return env
+
+
+def format_export_string(env_dict):
+    """Convert environment dictionary to srun --export format.
+
+    Args:
+        env_dict: Dictionary of environment variables
+
+    Returns:
+        String formatted for the srun --export flag (e.g., "KEY1=val1,KEY2=val2").
+        Returns "NONE" if no variables are specified.
+
+    Note:
+        Values containing commas are quoted to avoid conflicts with srun's delimiter.
+ """ + if not env_dict: + return "NONE" + + export_list = [] + for k, v in env_dict.items(): + # srun cannot handle values that contain commas + if ',' in v: + export_list.append(f"'{k}={v}'") + else: + export_list.append(f"{k}={v}") + return ",".join(export_list) + + def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var, gen_worker_env_var): @@ -226,16 +345,6 @@ def submit_job(config, log_dir, dry_run): env_config.setdefault('worker_env_var', '') env_config.setdefault('server_env_var', '') - worker_env_var = env_config.get('worker_env_var', '') - ctx_worker_env_var = env_config.get('ctx_worker_env_var', '') - gen_worker_env_var = env_config.get('gen_worker_env_var', '') - server_env_var = env_config.get('server_env_var', '') - if benchmark_config['mode'] == "gen_only_no_context": - worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1" - server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1" - if benchmark_config['mode'] == "gen_only": - worker_env_var += " TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1" - profiling_config = config.get('profiling', {}) profiling_config.setdefault('nsys_on', False) profiling_config.setdefault('ctx_profile_range', '10-30') @@ -320,7 +429,11 @@ def submit_job(config, log_dir, dry_run): os.makedirs(log_dir, exist_ok=True) print(f"Log will be saved to: {log_dir}") - # Save environment variables + # Save environment variables (for record-keeping only) + worker_env_var = env_config.get('worker_env_var', '') + ctx_worker_env_var = env_config.get('ctx_worker_env_var', '') + gen_worker_env_var = env_config.get('gen_worker_env_var', '') + server_env_var = env_config.get('server_env_var', '') save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var, worker_env_var, ctx_worker_env_var, gen_worker_env_var) @@ -353,24 +466,52 @@ def submit_job(config, log_dir, dry_run): start_server_cmds = [] container_mount_str = env_config['container_mount'] container_mount_str += f",{script_dir}:{script_dir}" + + # Pre-define server-type-specific configurations + server_configs = { + "GEN": { + "world_size": gen_world_size, + "profile_range": profiling_config['gen_profile_range'], + "config_path": gen_config_path + }, + "CTX": { + "world_size": ctx_world_size, + "profile_range": profiling_config['ctx_profile_range'], + "config_path": ctx_config_path + } + } + # Generate start worker commands with placeholder hostnames for server_type in allocations.keys(): + server_cfg = server_configs[server_type] + for server_id in allocations[server_type].keys(): allocation = allocations[server_type][server_id] - cuda_devices = ",".join([ - str(device) for device in list(allocation["nodes"].values())[0] - ]) - cur_worker_env_var = worker_env_var + \ - f" CUDA_VISIBLE_DEVICES={cuda_devices}" + \ - (f" {ctx_worker_env_var}" if server_type == "CTX" else "") + \ - (f" {gen_worker_env_var}" if server_type == "GEN" else "") + # Get GPU IDs for this server from allocation + # When multi-node, all nodes have same device list, so use first node [0] + gpu_ids = list(allocation["nodes"].values())[0] + + # Build environment for this worker + worker_env = build_worker_environment( + worker_config=worker_config, + env_config=env_config, + role=server_type, + benchmark_mode=benchmark_config['mode'], + nsys_on=profiling_config['nsys_on'], + profile_range=server_cfg['profile_range'], + concurrency=benchmark_config['concurrency_list'].split(',')[0], + gpu_ids=gpu_ids, + ) + export_str = format_export_string(worker_env) + # Use script_dir for start_worker.sh cmd = [ "srun -l", 
f"--nodelist {','.join(allocation['nodes'].keys())}", f"-N {len(allocation['nodes'])}", - f"--ntasks {gen_world_size if server_type == 'GEN' else ctx_world_size}", + f"--ntasks {server_cfg['world_size']}", f"--ntasks-per-node {gpus_per_node}", + f"--export=\"{export_str}\"", f"--container-image {env_config['container_image']}", f"--container-name {container_name}", f"--container-mounts {container_mount_str}", @@ -380,28 +521,27 @@ def submit_job(config, log_dir, dry_run): str(server_id), env_config['model_path'], str(allocation["port"]), - benchmark_config['mode'], - f"'{benchmark_config['concurrency_list']}'", str(slurm_config['numa_bind']).lower(), log_dir, str(profiling_config['nsys_on']).lower(), - f"'{profiling_config['gen_profile_range']}'" if server_type - == "GEN" else f"'{profiling_config['ctx_profile_range']}'", - gen_config_path if server_type == "GEN" else ctx_config_path, - f"'{cur_worker_env_var}'", + server_cfg['config_path'], f"&> {log_dir}/3_output_{server_type}_{server_id}.log &", ] start_server_cmds.append(" ".join(cmd)) # Generate start server commands (use script_dir for start_server.sh) + server_env = build_server_environment(env_config, benchmark_config['mode']) + export_str = format_export_string(server_env) + cmd = [ "srun -l", f"--nodelist {disagg_server_hostname}", f"--container-name={container_name}", + f"--export=\"{export_str}\"", f"--container-image={env_config['container_image']}", f"--container-mounts={container_mount_str}", f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1", - f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"", + f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')}", f"&> {log_dir}/4_output_server.log &", ] start_server_cmds.append(" ".join(cmd))