From cc0bbde745f96c903f343fe81d5edfd983b83410 Mon Sep 17 00:00:00 2001
From: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Date: Mon, 19 Jan 2026 15:53:48 +0800
Subject: [PATCH] [None][feat] Update disagg slurm scripts (#10712)

Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
---
 .../disaggregated/slurm/benchmark/config.yaml |   2 +-
 .../slurm/benchmark/disaggr_torch.slurm       |  11 +-
 .../disaggregated/slurm/benchmark/get_env.py  |  33 +++++
 .../slurm/benchmark/run_benchmark.sh          | 103 +++++++++++++-
 .../slurm/benchmark/run_benchmark_nv_sa.sh    |  98 +++++++++++++-
 .../slurm/benchmark/start_worker.sh           |   1 -
 .../disaggregated/slurm/benchmark/submit.py   | 128 ++++++++++++++----
 7 files changed, 341 insertions(+), 35 deletions(-)
 create mode 100644 examples/disaggregated/slurm/benchmark/get_env.py

diff --git a/examples/disaggregated/slurm/benchmark/config.yaml b/examples/disaggregated/slurm/benchmark/config.yaml
index b0952d9b7c..50761e7e8e 100644
--- a/examples/disaggregated/slurm/benchmark/config.yaml
+++ b/examples/disaggregated/slurm/benchmark/config.yaml
@@ -37,7 +37,7 @@ environment:
   cuda_architectures: ""  # Optional CUDA architectures to build for (e.g. "90-real;100-real"). If empty, builds for all architectures
   trtllm_wheel_path: ""  # Path to pre-built TensorRT-LLM wheel. If provided, install from this wheel instead
   work_dir: ""
-  worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes"
+  worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes NCCL_GRAPH_MIXING_SUPPORT=0"
   server_env_var: "TRTLLM_SERVER_DISABLE_GC=1"

 # Profiling Configuration
diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
index 1938db569f..cce8604176 100644
--- a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
+++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
@@ -124,9 +124,16 @@ elif [ -d "${trtllm_repo}" ]; then
     echo "TensorRT-LLM installation completed successfully"
 else
     echo "trtllm_wheel_path and trtllm_repo are not provided, will use the installed TensorRT-LLM from the container"
-    if [ -v TRT_LLM_GIT_COMMIT ]; then
-        echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
+    # get_env file is in the same directory as this script
+    get_env_file=${work_dir}/get_env.py
+    if ! srun --container-name=${container_name} \
+        --container-mounts=${container_mount} --no-container-mount-home \
+        --mpi=pmix --overlap -N 1 --ntasks-per-node=1 \
+        bash -c "python ${get_env_file} -e ${full_logdir}/env_vars.json" \
+        &> ${full_logdir}/2_get_env.log; then
+        cleanup_on_failure "Failed to get TensorRT-LLM environment variables. Check ${full_logdir}/2_get_env.log for details"
     fi
+    echo "TensorRT-LLM environment variables saved to ${full_logdir}/env_vars.json"
 fi

 # Get node lists and replace the placeholder with the actual node names
diff --git a/examples/disaggregated/slurm/benchmark/get_env.py b/examples/disaggregated/slurm/benchmark/get_env.py
new file mode 100644
index 0000000000..0120c9f0f1
--- /dev/null
+++ b/examples/disaggregated/slurm/benchmark/get_env.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import os
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Get TensorRT-LLM environment variables and save to JSON"
+    )
+    parser.add_argument("-e", "--env-file", required=True, help="Environment file path")
+    args = parser.parse_args()
+
+    # read env file, append new envs to it
+    with open(args.env_file, "r") as f:
+        env_data = json.load(f)
+
+    # Get environment variables
+    new_env_data = {
+        "TRT_LLM_GIT_COMMIT": os.environ.get("TRT_LLM_GIT_COMMIT", ""),
+        "TRT_LLM_VERSION": os.environ.get("TRT_LLM_VERSION", ""),
+    }
+    print(f"Environment variables: {new_env_data}")
+    env_data.update(new_env_data)
+    # Save to environment file
+    with open(args.env_file, "w") as f:
+        json.dump(env_data, f, indent=2)
+
+    print(f"Environment variables saved to {args.env_file}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/disaggregated/slurm/benchmark/run_benchmark.sh b/examples/disaggregated/slurm/benchmark/run_benchmark.sh
index fdf71b6ac1..47c83cf4b7 100644
--- a/examples/disaggregated/slurm/benchmark/run_benchmark.sh
+++ b/examples/disaggregated/slurm/benchmark/run_benchmark.sh
@@ -6,9 +6,9 @@ set -u
 trap 'echo "Error occurred at line $LINENO"; exit 1' ERR

 # Add parameter validation
-if [ "$#" -lt 9 ]; then
-    echo "Error: Missing required arguments"
-    echo "Usage: $0 model_name dataset_file multi_round concurrency_list streaming log_path hostname port"
+if [ "$#" -lt 10 ]; then
+    echo "Error: Missing required arguments, got $# arguments, args: $@"
+    echo "Usage: $0 model_name dataset_file multi_round num_gen_servers concurrency_list streaming log_path hostname port ucx_warmup_requests"
     exit 1
 fi

@@ -21,6 +21,7 @@ streaming=$6
 log_path=$7
 hostname=$8
 port=$9
+ucx_warmup_requests=${10}

 # check process id is not 0
 if [[ ${SLURM_PROCID} != "0" ]]; then
@@ -28,6 +29,99 @@ if [[ ${SLURM_PROCID} != "0" ]]; then
     exit 0
 fi

+do_get_logs(){
+    local input_file=$1
+    local output_file=$2
+    local mode=$3
+    local start_line=$4
+    # check mode is ctx or gen
+    if [ "${mode}" = "ctx" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_generation_tokens': 0" > ${output_file} || true
+    elif [ "${mode}" = "gen" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" > ${output_file} || true
+    else
+        echo "Invalid mode: ${mode}"
+        return 1
+    fi
+    return 0
+}
+
+do_process_all_logs(){
+    local input_folder=$1
+    local output_folder=$2
+    local mode=$3
+    if [ "${mode}" != "line" ] && [ "${mode}" != "log" ] && [ "${mode}" != "clean" ]; then
+        echo "Invalid mode: ${mode}"
+        exit 1
+    fi
+    local ctx_log
+    local ctx_num
+    local gen_log
+    local gen_num
+    local line_count
+    local start_line
+    for ctx_log in ${input_folder}/3_output_CTX_*.log; do
+        if [ -f "${ctx_log}" ]; then
+            ctx_num=$(basename "${ctx_log}" | sed 's/3_output_CTX_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${ctx_log})
+                echo ${line_count} > ${output_folder}/ctx_only_line_${ctx_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/ctx_only_line_${ctx_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/ctx_only_line_${ctx_num}.txt)
+                    rm -f ${output_folder}/ctx_only_line_${ctx_num}.txt
+                fi
+                do_get_logs ${ctx_log} ${output_folder}/ctx_only_${ctx_num}.txt "ctx" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${ctx_log}
+            fi
+        fi
+    done
+    # process all the gen log files in the input folder
+    for gen_log in ${input_folder}/3_output_GEN_*.log; do
+        if [ -f "${gen_log}" ]; then
+            gen_num=$(basename "${gen_log}" | sed 's/3_output_GEN_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${gen_log})
+                echo ${line_count} > ${output_folder}/gen_only_line_${gen_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/gen_only_line_${gen_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/gen_only_line_${gen_num}.txt)
+                    rm -f ${output_folder}/gen_only_line_${gen_num}.txt
+                fi
+                do_get_logs ${gen_log} ${output_folder}/gen_only_${gen_num}.txt "gen" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${gen_log}
+            fi
+        fi
+    done
+}
+
+mkdir -p ${log_path}/start_logs
+cp ${log_path}/3_output_CTX_*.log ${log_path}/start_logs/ 2>/dev/null || true
+cp ${log_path}/3_output_GEN_*.log ${log_path}/start_logs/ 2>/dev/null || true
+
+# warmup requests for ucx connections
+if [ "${ucx_warmup_requests}" -gt 0 ]; then
+    echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
+    python -m tensorrt_llm.serve.scripts.benchmark_serving \
+        --model ${model_name} \
+        --dataset-name random \
+        --random-ids \
+        --random-input-len 100 \
+        --random-output-len 10 \
+        --num-prompts ${ucx_warmup_requests} \
+        --host ${hostname} \
+        --port ${port} \
+        --ignore-eos \
+        --non-streaming
+    echo "UCX warmup done"
+fi
+
 echo "Hostname: ${hostname}, Port: ${port}"
 echo "Starting benchmark..."
 for concurrency in ${concurrency_list}; do
@@ -35,6 +129,7 @@ for concurrency in ${concurrency_list}; do
     num_prompts=$((concurrency * multi_round))
     echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
     mkdir -p ${log_path}/concurrency_${concurrency}
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"
     python -m tensorrt_llm.serve.scripts.benchmark_serving \
         --model ${model_name} \
         --backend openai \
@@ -53,4 +148,6 @@ for concurrency in ${concurrency_list}; do
         --percentile-metrics "ttft,tpot,itl,e2el" \
         $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi)
     echo "Benchmark with concurrency ${concurrency} done"
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "log"
 done
+# do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "clean"
diff --git a/examples/disaggregated/slurm/benchmark/run_benchmark_nv_sa.sh b/examples/disaggregated/slurm/benchmark/run_benchmark_nv_sa.sh
index ef3ea39f6b..8abcd7514d 100644
--- a/examples/disaggregated/slurm/benchmark/run_benchmark_nv_sa.sh
+++ b/examples/disaggregated/slurm/benchmark/run_benchmark_nv_sa.sh
@@ -42,6 +42,7 @@ streaming=$8
 log_path=$9
 hostname=${10}
 port=${11}
+ucx_warmup_requests=${12}

 # check process id is not 0
 if [[ ${SLURM_PROCID} != "0" ]]; then
@@ -59,14 +60,107 @@ fi

 echo "Cloning benchmark repository..."
 git clone "${BENCH_SERVING_REPO}" "${BENCH_SERVING_DIR}"
+
+do_get_logs(){
+    local input_file=$1
+    local output_file=$2
+    local mode=$3
+    local start_line=$4
+    # check mode is ctx or gen
+    if [ "${mode}" = "ctx" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_generation_tokens': 0" > ${output_file} || true
+    elif [ "${mode}" = "gen" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" > ${output_file} || true
+    else
+        echo "Invalid mode: ${mode}"
+        return 1
+    fi
+    return 0
+}
+
+do_process_all_logs(){
+    local input_folder=$1
+    local output_folder=$2
+    local mode=$3
+    if [ "${mode}" != "line" ] && [ "${mode}" != "log" ] && [ "${mode}" != "clean" ]; then
+        echo "Invalid mode: ${mode}"
+        exit 1
+    fi
+    local ctx_log
+    local ctx_num
+    local gen_log
+    local gen_num
+    local line_count
+    local start_line
+    for ctx_log in ${input_folder}/output_ctx_*.log; do
+        if [ -f "${ctx_log}" ]; then
+            ctx_num=$(basename "${ctx_log}" | sed 's/output_ctx_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${ctx_log})
+                echo ${line_count} > ${output_folder}/ctx_only_line_${ctx_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/ctx_only_line_${ctx_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/ctx_only_line_${ctx_num}.txt)
+                    rm -f ${output_folder}/ctx_only_line_${ctx_num}.txt
+                fi
+                do_get_logs ${ctx_log} ${output_folder}/ctx_only_${ctx_num}.txt "ctx" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${ctx_log}
+            fi
+        fi
+    done
+    # process all the gen log files in the input folder
+    for gen_log in ${input_folder}/output_gen_*.log; do
+        if [ -f "${gen_log}" ]; then
+            gen_num=$(basename "${gen_log}" | sed 's/output_gen_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${gen_log})
+                echo ${line_count} > ${output_folder}/gen_only_line_${gen_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/gen_only_line_${gen_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/gen_only_line_${gen_num}.txt)
+                    rm -f ${output_folder}/gen_only_line_${gen_num}.txt
+                fi
+                do_get_logs ${gen_log} ${output_folder}/gen_only_${gen_num}.txt "gen" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${gen_log}
+            fi
+        fi
+    done
+}
+
+mkdir -p ${log_path}/start_logs
+cp ${log_path}/output_ctx_*.log ${log_path}/start_logs/ 2>/dev/null || true
+cp ${log_path}/output_gen_*.log ${log_path}/start_logs/ 2>/dev/null || true
+
+# warmup requests for ucx connections
+if [ "${ucx_warmup_requests}" -gt 0 ]; then
+    echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
+    python -m tensorrt_llm.serve.scripts.benchmark_serving \
+        --model ${model_name} \
+        --dataset-name random \
+        --random-ids \
+        --random-input-len 100 \
+        --random-output-len 10 \
+        --num-prompts ${ucx_warmup_requests} \
+        --host ${hostname} \
+        --port ${port} \
+        --ignore-eos \
+        --non-streaming
+    echo "UCX warmup done"
+fi
+
 # Run benchmarks
 echo "Starting benchmark..."
 for concurrency in ${concurrency_list}; do
     concurrency=$((concurrency * num_gen_servers))
     num_prompts=$((concurrency * multi_round))
     output_dir="${log_path}/concurrency_${concurrency}"
-
     echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"

     mkdir -p "${output_dir}"
     python "${BENCH_SCRIPT}" \
@@ -89,4 +183,6 @@
         $([ "${streaming}" = "false" ] && echo "--non-streaming")

     echo "Benchmark with concurrency ${concurrency} done"
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "log"
 done
+# do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "clean"
diff --git a/examples/disaggregated/slurm/benchmark/start_worker.sh b/examples/disaggregated/slurm/benchmark/start_worker.sh
index 7e995fe18b..d96201aa7e 100644
--- a/examples/disaggregated/slurm/benchmark/start_worker.sh
+++ b/examples/disaggregated/slurm/benchmark/start_worker.sh
@@ -34,7 +34,6 @@ else
 fi

 if [ "${benchmark_mode}" = "gen_only" ]; then
-    export TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1
     export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
 fi

diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py
index d7e9f0a621..40d158e62d 100644
--- a/examples/disaggregated/slurm/benchmark/submit.py
+++ b/examples/disaggregated/slurm/benchmark/submit.py
@@ -168,6 +168,27 @@ def replace_env_in_file(log_dir, file_path, env_var):
     return tmp_dir


+def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var,
+                  gen_worker_env_var):
+
+    def get_env_var_str(env_var_str):
+        env_data = {}
+        for env_var in env_var_str.split():
+            if '=' in env_var:
+                key, value = env_var.split('=', 1)
+                env_data[key] = value
+        return env_data
+
+    env_data = {}
+    env_data['server_env_var'] = get_env_var_str(server_env_var)
+    env_data['worker_env_var'] = get_env_var_str(worker_env_var)
+    env_data['ctx_worker_env_var'] = get_env_var_str(ctx_worker_env_var)
+    env_data['gen_worker_env_var'] = get_env_var_str(gen_worker_env_var)
+    with open(env_file, 'w') as f:
+        json.dump(env_data, f, indent=2)
+    print(f"Environment variables saved to {env_file}")
+
+
 def submit_job(config, log_dir, dry_run):
     # Extract configurations
     slurm_config = config['slurm']
@@ -179,6 +200,11 @@ def submit_job(config, log_dir, dry_run):
     worker_config = config['worker_config']
     benchmark_config = config['benchmark']

+    if 'work_dir' in env_config and os.path.isdir(env_config['work_dir']):
+        script_dir = env_config['work_dir']
+    else:
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+
     # Set default accuracy configuration for backward compatibility
     if 'accuracy' not in config:
         config['accuracy'] = {
@@ -200,11 +226,15 @@ def submit_job(config, log_dir, dry_run):
     env_config.setdefault('worker_env_var', '')
     env_config.setdefault('server_env_var', '')

-    worker_env_var = env_config.get('worker_env_var')
-    server_env_var = env_config.get('server_env_var')
+    worker_env_var = env_config.get('worker_env_var', '')
+    ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
+    gen_worker_env_var = env_config.get('gen_worker_env_var', '')
+    server_env_var = env_config.get('server_env_var', '')
     if benchmark_config['mode'] == "gen_only_no_context":
         worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
         server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
+    if benchmark_config['mode'] == "gen_only":
+        worker_env_var += " TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1"

     profiling_config = config.get('profiling', {})
     profiling_config.setdefault('nsys_on', False)
@@ -228,6 +258,8 @@
     gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1)
     gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
     gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)
+    ucx_warmup_requests = 2 * ctx_world_size * \
+        gen_world_size if benchmark_config['mode'] == "e2e" else 0

     total_nodes = ctx_nodes + gen_nodes
     total_tasks = total_nodes * gpus_per_node
@@ -251,8 +283,10 @@
                                 {}).get('num_nextn_predict_layers', 0)

     # Create base log directory path
+    if 'log_dir' in env_config and env_config['log_dir']:
+        log_dir = env_config['log_dir']
     if log_dir is None:
-        log_base = os.path.join(env_config['work_dir'], "logs")
+        log_base = os.path.join(script_dir, "logs")
         date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S")
         log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}")

@@ -266,13 +300,30 @@
     # Create full log directory path
     log_dir = os.path.join(log_base, dir_suffix)

-    # Remove existing directory if it exists
+    # if trtllm_config.yaml exists, don't remove the directory, remove other files in the directory except trtllm_config.yaml
+    # also don't remove concurrency_* folders
     if os.path.exists(log_dir):
-        print(f"[WARNING] Removing existing log directory: {log_dir}")
-        shutil.rmtree(log_dir)
-    os.makedirs(log_dir)
+        if not os.path.exists(os.path.join(log_dir, 'trtllm_config.yaml')):
+            print(f"[WARNING] Removing existing log directory: {log_dir}")
+            shutil.rmtree(log_dir)
+        else:
+            print(
+                f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}"
+            )
+            for file in os.listdir(log_dir):
+                if file != 'trtllm_config.yaml' and not file.startswith(
+                        'concurrency_'):
+                    if os.path.isdir(os.path.join(log_dir, file)):
+                        shutil.rmtree(os.path.join(log_dir, file))
+                    else:
+                        os.remove(os.path.join(log_dir, file))
+    os.makedirs(log_dir, exist_ok=True)
     print(f"Log will be saved to: {log_dir}")

+    # Save environment variables
+    save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var,
+                  worker_env_var, ctx_worker_env_var, gen_worker_env_var)
+
     # Setup config file paths and save worker configs
     ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
     gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
@@ -300,6 +351,8 @@ def submit_job(config, log_dir, dry_run):
     container_name = "disaggr-test"

     start_server_cmds = []
+    container_mount_str = env_config['container_mount']
+    container_mount_str += f",{script_dir}:{script_dir}"
     # Generate start worker commands with placeholder hostnames
     for server_type in allocations.keys():
         for server_id in allocations[server_type].keys():
@@ -307,7 +360,11 @@
             cuda_devices = ",".join([
                 str(device) for device in list(allocation["nodes"].values())[0]
             ])
-            cur_worker_env_var = worker_env_var + f" CUDA_VISIBLE_DEVICES={cuda_devices}"
+            cur_worker_env_var = worker_env_var + \
+                f" CUDA_VISIBLE_DEVICES={cuda_devices}" + \
+                (f" {ctx_worker_env_var}" if server_type == "CTX" else "") + \
+                (f" {gen_worker_env_var}" if server_type == "GEN" else "")
+            # Use script_dir for start_worker.sh
             cmd = [
                 "srun -l",
                 f"--nodelist {','.join(allocation['nodes'].keys())}",
@@ -316,9 +373,9 @@
                 f"--ntasks-per-node {gpus_per_node}",
                 f"--container-image {env_config['container_image']}",
                 f"--container-name {container_name}",
-                f"--container-mounts {env_config['container_mount']}",
+                f"--container-mounts {container_mount_str}",
                 "--no-container-mount-home --mpi=pmix --overlap",
-                f"bash {os.path.join(env_config['work_dir'], 'start_worker.sh')}",
f"bash {os.path.join(script_dir, 'start_worker.sh')}", server_type, str(server_id), env_config['model_path'], @@ -336,26 +393,26 @@ def submit_job(config, log_dir, dry_run): ] start_server_cmds.append(" ".join(cmd)) - # Generate start server commands + # Generate start server commands (use script_dir for start_server.sh) cmd = [ "srun -l", f"--nodelist {disagg_server_hostname}", f"--container-name={container_name}", f"--container-image={env_config['container_image']}", - f"--container-mounts={env_config['container_mount']}", + f"--container-mounts={container_mount_str}", f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1", - f"bash {env_config['work_dir']}/start_server.sh {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"", + f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"", f"&> {log_dir}/4_output_server.log &", ] start_server_cmds.append(" ".join(cmd)) - # Generate wait server command + # Generate wait server command (use script_dir for wait_server.sh) cmd = [ "srun -l", f"--container-name={container_name}", - f"--container-mounts={env_config['container_mount']}", + f"--container-mounts={container_mount_str}", f"--mpi=pmix --overlap -N 1 -n 1", - f"bash {env_config['work_dir']}/wait_server.sh {disagg_server_hostname} {disagg_server_port}", + f"bash {os.path.join(script_dir, 'wait_server.sh')} {disagg_server_hostname} {disagg_server_port}", f"&> {log_dir}/5_wait_server.log", ] start_server_cmds.append(" ".join(cmd)) @@ -363,14 +420,13 @@ def submit_job(config, log_dir, dry_run): with open(os.path.join(log_dir, "start_server_cmds.sh"), "w") as f: f.write("\n".join(start_server_cmds) + "\n") - # Generate client commands + # Generate client commands (use script_dir for benchmark scripts) client_cmds = [] client_slurm_prefix = [ f"srun -l --container-name={container_name}", - f"--container-mounts={env_config['container_mount']}", + f"--container-mounts={container_mount_str}", f"--mpi=pmix --overlap -N 1 -n 1", ] - # Append benchmark commands if benchmark_config.get('enable_benchmark', True): env_var = config['benchmark'].get('env_var', {}) @@ -378,16 +434,21 @@ def submit_job(config, log_dir, dry_run): f"--export \"{convert_envs_to_str(env_var)}\"" ] if benchmark_config['use_nv_sa_benchmark']: + if benchmark_config['mode'] == "gen_only": + print( + f"[ERROR] SA benchmark client script is not supported for gen_only mode" + ) + sys.exit(1) benchmark_cmd = [ - f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh", - f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}", + f"bash {os.path.join(script_dir, 'run_benchmark_nv_sa.sh')}", + f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}", f"&> {log_dir}/6_bench.log" ] client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) else: benchmark_cmd = [ - f"bash {env_config['work_dir']}/run_benchmark.sh", - f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} 
{disagg_server_port}", + f"bash {os.path.join(script_dir, 'run_benchmark.sh')}", + f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}", f"&> {log_dir}/6_bench.log" ] client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd)) @@ -426,6 +487,18 @@ def submit_job(config, log_dir, dry_run): with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f: f.write("\n".join(client_cmds) + "\n") + # Resolve slurm script_file path + # If it's a relative path, make it relative to script_dir + slurm_script_file = slurm_config['script_file'] + if not os.path.isabs(slurm_script_file): + slurm_script_file = os.path.join(script_dir, slurm_script_file) + + # Verify the script file exists + if not os.path.exists(slurm_script_file): + print(f"[ERROR] SLURM script file not found: {slurm_script_file}", + file=sys.stderr) + sys.exit(1) + # Prepare sbatch command # yapf: disable cmd = [ @@ -437,21 +510,22 @@ def submit_job(config, log_dir, dry_run): f'--nodes={total_nodes}', f'--ntasks={total_tasks}', f'--ntasks-per-node={hw_config["gpus_per_node"]}', - *([] if not slurm_config['set_segment'] else [f'--segment={total_nodes}']), + *([] if not slurm_config['set_segment'] + else [f'--segment={total_nodes}']), f'--output={log_dir}/slurm-%j.out', f'--error={log_dir}/slurm-%j.err', *([arg for arg in slurm_config['extra_args'].split() if arg]), - slurm_config['script_file'], + slurm_script_file, # Benchmark Configuration '--benchmark-mode', benchmark_config['mode'], # Environment and paths '--trtllm-repo', env_config['trtllm_repo'], - '--work-dir', env_config['work_dir'], + '--work-dir', script_dir, '--full-logdir', log_dir, '--container-name', container_name, - '--container-mount', env_config['container_mount'], + '--container-mount', container_mount_str, '--container-image', env_config['container_image'], '--build-wheel', str(env_config['build_wheel']).lower(), '--cuda-architectures', env_config['cuda_architectures'],
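
The log helpers this patch adds work in two passes: "line" mode records each worker log's current length just before a concurrency level runs, and "log" mode afterwards slices out only the lines appended since, keeping context-only iteration stats ('num_generation_tokens': 0) for CTX workers and generation-only stats ('num_ctx_requests': 0, 'num_ctx_tokens': 0) for GEN workers. The UCX warmup count in submit.py is 2 * ctx_world_size * gen_world_size in e2e mode; for example, a ctx worker spanning 4 GPUs and a gen worker spanning 8 GPUs yield 2 * 4 * 8 = 64 warmup requests. Below is a minimal Python sketch of the same slicing idea; it is not part of the patch, and the file and function names are illustrative.

#!/usr/bin/env python3
# Sketch (not part of the patch): incremental log slicing in the style of
# the do_get_logs/do_process_all_logs shell helpers. Names are illustrative.
from pathlib import Path

CTX_MARKER = "'num_generation_tokens': 0"                  # ctx-only iterations
GEN_MARKER = "'num_ctx_requests': 0, 'num_ctx_tokens': 0"  # gen-only iterations


def snapshot(log_file: Path) -> int:
    """'line' mode: record the log's current line count."""
    with log_file.open(errors="replace") as f:
        return sum(1 for _ in f)


def slice_new_stats(log_file: Path, start_line: int, marker: str) -> list[str]:
    """'log' mode: keep marker lines at or after start_line."""
    with log_file.open(errors="replace") as f:
        return [line for i, line in enumerate(f, 1)
                if i >= start_line and marker in line]


if __name__ == "__main__":
    log = Path("3_output_CTX_0.log")   # per-worker log; name illustrative
    if log.exists():
        start = snapshot(log)          # before one concurrency level runs
        # ... run the benchmark for this concurrency level ...
        stats = slice_new_stats(log, start, CTX_MARKER)
        Path("ctx_only_0.txt").write_text("".join(stats))

The shell helpers reuse the recorded wc -l count directly as the sed start line, so each slice re-reads from the last previously counted line; the sketch keeps the same inclusive semantics.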