[None][feat] Update disagg slurm scripts (#10712)

Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
This commit is contained in:
Xianjie Qiao 2026-01-19 15:53:48 +08:00 committed by GitHub
parent 32ab809f36
commit cc0bbde745
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 341 additions and 35 deletions

View File

@ -37,7 +37,7 @@ environment:
cuda_architectures: "" # Optional CUDA architectures to build for (e.g. "90-real;100-real"). If empty, builds for all architectures
trtllm_wheel_path: "" # Path to pre-built TensorRT-LLM wheel. If provided, install from this wheel instead
work_dir: "<full_path_to_work_dir>"
worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes"
worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes NCCL_GRAPH_MIXING_SUPPORT=0"
server_env_var: "TRTLLM_SERVER_DISABLE_GC=1"
# Profiling Configuration

View File

@ -124,9 +124,16 @@ elif [ -d "${trtllm_repo}" ]; then
echo "TensorRT-LLM installation completed successfully"
else
echo "trtllm_wheel_path and trtllm_repo are not provided, will use the installed TensorRT-LLM from the container"
if [ -v TRT_LLM_GIT_COMMIT ]; then
echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
# get_env file is in the same directory as this script
get_env_file=${work_dir}/get_env.py
if ! srun --container-name=${container_name} \
--container-mounts=${container_mount} --no-container-mount-home \
--mpi=pmix --overlap -N 1 --ntasks-per-node=1 \
bash -c "python ${get_env_file} -e ${full_logdir}/env_vars.json" \
&> ${full_logdir}/2_get_env.log; then
cleanup_on_failure "Failed to get TensorRT-LLM environment variables. Check ${full_logdir}/2_get_env.log for details"
fi
echo "TensorRT-LLM environment variables saved to ${full_logdir}/env_vars.json"
fi
# Get node lists and replace the placeholder with the actual node names

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python3
import argparse
import json
import os
def main():
    """Collect TensorRT-LLM build metadata from the container environment and
    merge it into an existing JSON env file.

    The env file is normally created earlier by the submit script
    (save_env_file); this tool appends TRT_LLM_GIT_COMMIT / TRT_LLM_VERSION
    to it so the log directory records exactly which build ran.
    """
    parser = argparse.ArgumentParser(
        description="Get TensorRT-LLM environment variables and save to JSON")
    parser.add_argument("-e",
                        "--env-file",
                        required=True,
                        help="Environment file path")
    args = parser.parse_args()

    # Read the existing env file so we append rather than overwrite.
    # Tolerate a missing or corrupt file instead of crashing the job step:
    # in that case start from an empty mapping.
    try:
        with open(args.env_file, "r") as f:
            env_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        env_data = {}

    # Build metadata exported by the TRT-LLM container; empty string when the
    # container does not define them.
    new_env_data = {
        "TRT_LLM_GIT_COMMIT": os.environ.get("TRT_LLM_GIT_COMMIT", ""),
        "TRT_LLM_VERSION": os.environ.get("TRT_LLM_VERSION", ""),
    }
    print(f"Environment variables: {new_env_data}")
    env_data.update(new_env_data)

    # Save the merged mapping back to the same file.
    with open(args.env_file, "w") as f:
        json.dump(env_data, f, indent=2)
    print(f"Environment variables saved to {args.env_file}")


if __name__ == "__main__":
    main()

View File

@ -6,9 +6,9 @@ set -u
trap 'echo "Error occurred at line $LINENO"; exit 1' ERR
# Add parameter validation
if [ "$#" -lt 9 ]; then
echo "Error: Missing required arguments"
echo "Usage: $0 model_name dataset_file multi_round concurrency_list streaming log_path hostname port"
if [ "$#" -lt 10 ]; then
echo "Error: Missing required arguments, got $# arguments, args: $@"
echo "Usage: $0 model_name dataset_file multi_round num_gen_servers concurrency_list streaming log_path hostname port ucx_warmup_requests"
exit 1
fi
@ -21,6 +21,7 @@ streaming=$6
log_path=$7
hostname=$8
port=$9
ucx_warmup_requests=${10}
# check process id is not 0
if [[ ${SLURM_PROCID} != "0" ]]; then
@ -28,6 +29,99 @@ if [[ ${SLURM_PROCID} != "0" ]]; then
exit 0
fi
do_get_logs(){
    # Extract per-iteration scheduler stats from a worker log.
    #   $1 input_file  - worker log to scan
    #   $2 output_file - destination (overwritten) for matching lines
    #   $3 mode        - "ctx" (context-only stats) or "gen" (generation-only stats)
    #   $4 start_line  - first line to scan; values < 1 are clamped to 1
    local input_file=$1
    local output_file=$2
    local mode=$3
    local start_line=$4
    # sed line addresses start at 1; the caller passes 0 when no previous
    # line-count snapshot exists, and GNU sed rejects the address range
    # "0,$" outright — clamp so the extraction does not abort.
    if [ "${start_line}" -lt 1 ]; then
        start_line=1
    fi
    # grep -a: logs may contain binary noise; treat them as text anyway.
    if [ "${mode}" = "ctx" ]; then
        sed -n "${start_line},\$p" "${input_file}" | grep -a "'num_generation_tokens': 0" > "${output_file}" || true
    elif [ "${mode}" = "gen" ]; then
        sed -n "${start_line},\$p" "${input_file}" | grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" > "${output_file}" || true
    else
        echo "Invalid mode: ${mode}"
        return 1
    fi
    return 0
}
do_process_all_logs(){
    # Walk all ctx/gen worker logs under a folder and either snapshot their
    # line counts ("line"), extract new stats since the snapshot ("log"),
    # or delete them ("clean").
    #   $1 input_folder  - folder holding 3_output_CTX_*.log / 3_output_GEN_*.log
    #   $2 output_folder - folder for *_only_line_N.txt snapshots and *_only_N.txt extracts
    #   $3 mode          - "line" | "log" | "clean"
    local input_folder=$1
    local output_folder=$2
    local mode=$3
    case "${mode}" in
        line|log|clean) ;;
        *)
            echo "Invalid mode: ${mode}"
            exit 1
            ;;
    esac
    local kind tag log_file idx line_count start_line
    # ctx and gen logs get identical treatment apart from the filename tag.
    for kind in ctx gen; do
        if [ "${kind}" = "ctx" ]; then tag="CTX"; else tag="GEN"; fi
        for log_file in ${input_folder}/3_output_${tag}_*.log; do
            [ -f "${log_file}" ] || continue
            idx=$(basename "${log_file}" | sed "s/3_output_${tag}_\([0-9]*\)\.log/\1/")
            case "${mode}" in
                line)
                    # Remember how long the log is now, so a later "log" pass
                    # only extracts lines appended after this point.
                    line_count=$(wc -l < ${log_file})
                    echo ${line_count} > ${output_folder}/${kind}_only_line_${idx}.txt
                    ;;
                log)
                    if [ -f "${output_folder}/${kind}_only_line_${idx}.txt" ]; then
                        start_line=$(cat ${output_folder}/${kind}_only_line_${idx}.txt)
                        rm -f ${output_folder}/${kind}_only_line_${idx}.txt
                    else
                        start_line=0
                    fi
                    do_get_logs ${log_file} ${output_folder}/${kind}_only_${idx}.txt "${kind}" ${start_line}
                    ;;
                clean)
                    rm -f ${log_file}
                    ;;
            esac
        done
    done
}
mkdir -p ${log_path}/start_logs
cp ${log_path}/3_output_CTX_*.log ${log_path}/start_logs/ 2>/dev/null || true
cp ${log_path}/3_output_GEN_*.log ${log_path}/start_logs/ 2>/dev/null || true
# warmup requests for ucx connections
if [ "${ucx_warmup_requests}" -gt 0 ]; then
echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
python -m tensorrt_llm.serve.scripts.benchmark_serving \
--model ${model_name} \
--dataset-name random \
--random-ids \
--random-input-len 100 \
--random-output-len 10 \
--num-prompts ${ucx_warmup_requests} \
--host ${hostname} \
--port ${port} \
--ignore-eos \
--non-streaming
echo "UCX warmup done"
fi
echo "Hostname: ${hostname}, Port: ${port}"
echo "Starting benchmark..."
for concurrency in ${concurrency_list}; do
@ -35,6 +129,7 @@ for concurrency in ${concurrency_list}; do
num_prompts=$((concurrency * multi_round))
echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
mkdir -p ${log_path}/concurrency_${concurrency}
do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"
python -m tensorrt_llm.serve.scripts.benchmark_serving \
--model ${model_name} \
--backend openai \
@ -53,4 +148,6 @@ for concurrency in ${concurrency_list}; do
--percentile-metrics "ttft,tpot,itl,e2el" \
$(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi)
echo "Benchmark with concurrency ${concurrency} done"
do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "log"
done
# do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "clean"

View File

@ -42,6 +42,7 @@ streaming=$8
log_path=$9
hostname=${10}
port=${11}
ucx_warmup_requests=${12}
# check process id is not 0
if [[ ${SLURM_PROCID} != "0" ]]; then
@ -59,14 +60,107 @@ fi
echo "Cloning benchmark repository..."
git clone "${BENCH_SERVING_REPO}" "${BENCH_SERVING_DIR}"
do_get_logs(){
    # Extract per-iteration scheduler stats from a worker log.
    #   $1 input_file  - worker log to scan
    #   $2 output_file - destination (overwritten) for matching lines
    #   $3 mode        - "ctx" (context-only stats) or "gen" (generation-only stats)
    #   $4 start_line  - first line to scan; values < 1 are clamped to 1
    local input_file=$1
    local output_file=$2
    local mode=$3
    local start_line=$4
    # sed line addresses start at 1; the caller passes 0 when no previous
    # line-count snapshot exists, and GNU sed rejects the address range
    # "0,$" outright — clamp so the extraction does not abort.
    if [ "${start_line}" -lt 1 ]; then
        start_line=1
    fi
    # grep -a: logs may contain binary noise; treat them as text anyway.
    if [ "${mode}" = "ctx" ]; then
        sed -n "${start_line},\$p" "${input_file}" | grep -a "'num_generation_tokens': 0" > "${output_file}" || true
    elif [ "${mode}" = "gen" ]; then
        sed -n "${start_line},\$p" "${input_file}" | grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" > "${output_file}" || true
    else
        echo "Invalid mode: ${mode}"
        return 1
    fi
    return 0
}
do_process_all_logs(){
    # Walk all ctx/gen worker logs under a folder and either snapshot their
    # line counts ("line"), extract new stats since the snapshot ("log"),
    # or delete them ("clean").
    #   $1 input_folder  - folder holding output_ctx_*.log / output_gen_*.log
    #   $2 output_folder - folder for *_only_line_N.txt snapshots and *_only_N.txt extracts
    #   $3 mode          - "line" | "log" | "clean"
    local input_folder=$1
    local output_folder=$2
    local mode=$3
    case "${mode}" in
        line|log|clean) ;;
        *)
            echo "Invalid mode: ${mode}"
            exit 1
            ;;
    esac
    local kind log_file idx line_count start_line
    # ctx and gen logs get identical treatment apart from the filename tag.
    for kind in ctx gen; do
        for log_file in ${input_folder}/output_${kind}_*.log; do
            [ -f "${log_file}" ] || continue
            idx=$(basename "${log_file}" | sed "s/output_${kind}_\([0-9]*\)\.log/\1/")
            case "${mode}" in
                line)
                    # Remember how long the log is now, so a later "log" pass
                    # only extracts lines appended after this point.
                    line_count=$(wc -l < ${log_file})
                    echo ${line_count} > ${output_folder}/${kind}_only_line_${idx}.txt
                    ;;
                log)
                    if [ -f "${output_folder}/${kind}_only_line_${idx}.txt" ]; then
                        start_line=$(cat ${output_folder}/${kind}_only_line_${idx}.txt)
                        rm -f ${output_folder}/${kind}_only_line_${idx}.txt
                    else
                        start_line=0
                    fi
                    do_get_logs ${log_file} ${output_folder}/${kind}_only_${idx}.txt "${kind}" ${start_line}
                    ;;
                clean)
                    rm -f ${log_file}
                    ;;
            esac
        done
    done
}
mkdir -p ${log_path}/start_logs
cp ${log_path}/output_ctx_*.log ${log_path}/start_logs/ 2>/dev/null || true
cp ${log_path}/output_gen_*.log ${log_path}/start_logs/ 2>/dev/null || true
# warmup requests for ucx connections
if [ "${ucx_warmup_requests}" -gt 0 ]; then
echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
python -m tensorrt_llm.serve.scripts.benchmark_serving \
--model ${model_name} \
--dataset-name random \
--random-ids \
--random-input-len 100 \
--random-output-len 10 \
--num-prompts ${ucx_warmup_requests} \
--host ${hostname} \
--port ${port} \
--ignore-eos \
--non-streaming
echo "UCX warmup done"
fi
# Run benchmarks
echo "Starting benchmark..."
for concurrency in ${concurrency_list}; do
concurrency=$((concurrency * num_gen_servers))
num_prompts=$((concurrency * multi_round))
output_dir="${log_path}/concurrency_${concurrency}"
echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"
mkdir -p "${output_dir}"
python "${BENCH_SCRIPT}" \
@ -89,4 +183,6 @@ for concurrency in ${concurrency_list}; do
$([ "${streaming}" = "false" ] && echo "--non-streaming")
echo "Benchmark with concurrency ${concurrency} done"
do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "log"
done
# do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "clean"

View File

@ -34,7 +34,6 @@ else
fi
if [ "${benchmark_mode}" = "gen_only" ]; then
export TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1
export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
fi

View File

@ -168,6 +168,27 @@ def replace_env_in_file(log_dir, file_path, env_var):
return tmp_dir
def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var,
                  gen_worker_env_var):
    """Persist the four env-var strings as structured JSON.

    Each input is a whitespace-separated "KEY=VALUE ..." string; tokens
    without '=' are ignored, and only the first '=' splits key from value
    (so values may themselves contain '='). The result is written to
    ``env_file`` as a mapping of section name -> {key: value}.
    """

    def parse_env_string(raw):
        # Split on the first '=' only; skip tokens that are not assignments.
        pairs = (token.split('=', 1) for token in raw.split() if '=' in token)
        return {key: value for key, value in pairs}

    env_data = {
        'server_env_var': parse_env_string(server_env_var),
        'worker_env_var': parse_env_string(worker_env_var),
        'ctx_worker_env_var': parse_env_string(ctx_worker_env_var),
        'gen_worker_env_var': parse_env_string(gen_worker_env_var),
    }
    with open(env_file, 'w') as f:
        json.dump(env_data, f, indent=2)
    print(f"Environment variables saved to {env_file}")
def submit_job(config, log_dir, dry_run):
# Extract configurations
slurm_config = config['slurm']
@ -179,6 +200,11 @@ def submit_job(config, log_dir, dry_run):
worker_config = config['worker_config']
benchmark_config = config['benchmark']
if 'work_dir' in env_config and os.path.isdir(env_config['work_dir']):
script_dir = env_config['work_dir']
else:
script_dir = os.path.dirname(os.path.abspath(__file__))
# Set default accuracy configuration for backward compatibility
if 'accuracy' not in config:
config['accuracy'] = {
@ -200,11 +226,15 @@ def submit_job(config, log_dir, dry_run):
env_config.setdefault('worker_env_var', '')
env_config.setdefault('server_env_var', '')
worker_env_var = env_config.get('worker_env_var')
server_env_var = env_config.get('server_env_var')
worker_env_var = env_config.get('worker_env_var', '')
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
server_env_var = env_config.get('server_env_var', '')
if benchmark_config['mode'] == "gen_only_no_context":
worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
if benchmark_config['mode'] == "gen_only":
worker_env_var += " TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1"
profiling_config = config.get('profiling', {})
profiling_config.setdefault('nsys_on', False)
@ -228,6 +258,8 @@ def submit_job(config, log_dir, dry_run):
gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1)
gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)
ucx_warmup_requests = 2 * ctx_world_size * \
gen_world_size if benchmark_config['mode'] == "e2e" else 0
total_nodes = ctx_nodes + gen_nodes
total_tasks = total_nodes * gpus_per_node
@ -251,8 +283,10 @@ def submit_job(config, log_dir, dry_run):
{}).get('num_nextn_predict_layers', 0)
# Create base log directory path
if 'log_dir' in env_config and env_config['log_dir']:
log_dir = env_config['log_dir']
if log_dir is None:
log_base = os.path.join(env_config['work_dir'], "logs")
log_base = os.path.join(script_dir, "logs")
date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S")
log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}")
@ -266,13 +300,30 @@ def submit_job(config, log_dir, dry_run):
# Create full log directory path
log_dir = os.path.join(log_base, dir_suffix)
# Remove existing directory if it exists
# if trtllm_config.yaml exists, don't remove the directory, remove other files in the directory except trtllm_config.yaml
# also don't remove concurrency_* folders
if os.path.exists(log_dir):
print(f"[WARNING] Removing existing log directory: {log_dir}")
shutil.rmtree(log_dir)
os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir, 'trtllm_config.yaml')):
print(f"[WARNING] Removing existing log directory: {log_dir}")
shutil.rmtree(log_dir)
else:
print(
f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}"
)
for file in os.listdir(log_dir):
if file != 'trtllm_config.yaml' and not file.startswith(
'concurrency_'):
if os.path.isdir(os.path.join(log_dir, file)):
shutil.rmtree(os.path.join(log_dir, file))
else:
os.remove(os.path.join(log_dir, file))
os.makedirs(log_dir, exist_ok=True)
print(f"Log will be saved to: {log_dir}")
# Save environment variables
save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var,
worker_env_var, ctx_worker_env_var, gen_worker_env_var)
# Setup config file paths and save worker configs
ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
@ -300,6 +351,8 @@ def submit_job(config, log_dir, dry_run):
container_name = "disaggr-test"
start_server_cmds = []
container_mount_str = env_config['container_mount']
container_mount_str += f",{script_dir}:{script_dir}"
# Generate start worker commands with placeholder hostnames
for server_type in allocations.keys():
for server_id in allocations[server_type].keys():
@ -307,7 +360,11 @@ def submit_job(config, log_dir, dry_run):
cuda_devices = ",".join([
str(device) for device in list(allocation["nodes"].values())[0]
])
cur_worker_env_var = worker_env_var + f" CUDA_VISIBLE_DEVICES={cuda_devices}"
cur_worker_env_var = worker_env_var + \
f" CUDA_VISIBLE_DEVICES={cuda_devices}" + \
(f" {ctx_worker_env_var}" if server_type == "CTX" else "") + \
(f" {gen_worker_env_var}" if server_type == "GEN" else "")
# Use script_dir for start_worker.sh
cmd = [
"srun -l",
f"--nodelist {','.join(allocation['nodes'].keys())}",
@ -316,9 +373,9 @@ def submit_job(config, log_dir, dry_run):
f"--ntasks-per-node {gpus_per_node}",
f"--container-image {env_config['container_image']}",
f"--container-name {container_name}",
f"--container-mounts {env_config['container_mount']}",
f"--container-mounts {container_mount_str}",
"--no-container-mount-home --mpi=pmix --overlap",
f"bash {os.path.join(env_config['work_dir'], 'start_worker.sh')}",
f"bash {os.path.join(script_dir, 'start_worker.sh')}",
server_type,
str(server_id),
env_config['model_path'],
@ -336,26 +393,26 @@ def submit_job(config, log_dir, dry_run):
]
start_server_cmds.append(" ".join(cmd))
# Generate start server commands
# Generate start server commands (use script_dir for start_server.sh)
cmd = [
"srun -l",
f"--nodelist {disagg_server_hostname}",
f"--container-name={container_name}",
f"--container-image={env_config['container_image']}",
f"--container-mounts={env_config['container_mount']}",
f"--container-mounts={container_mount_str}",
f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1",
f"bash {env_config['work_dir']}/start_server.sh {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
f"&> {log_dir}/4_output_server.log &",
]
start_server_cmds.append(" ".join(cmd))
# Generate wait server command
# Generate wait server command (use script_dir for wait_server.sh)
cmd = [
"srun -l",
f"--container-name={container_name}",
f"--container-mounts={env_config['container_mount']}",
f"--container-mounts={container_mount_str}",
f"--mpi=pmix --overlap -N 1 -n 1",
f"bash {env_config['work_dir']}/wait_server.sh {disagg_server_hostname} {disagg_server_port}",
f"bash {os.path.join(script_dir, 'wait_server.sh')} {disagg_server_hostname} {disagg_server_port}",
f"&> {log_dir}/5_wait_server.log",
]
start_server_cmds.append(" ".join(cmd))
@ -363,14 +420,13 @@ def submit_job(config, log_dir, dry_run):
with open(os.path.join(log_dir, "start_server_cmds.sh"), "w") as f:
f.write("\n".join(start_server_cmds) + "\n")
# Generate client commands
# Generate client commands (use script_dir for benchmark scripts)
client_cmds = []
client_slurm_prefix = [
f"srun -l --container-name={container_name}",
f"--container-mounts={env_config['container_mount']}",
f"--container-mounts={container_mount_str}",
f"--mpi=pmix --overlap -N 1 -n 1",
]
# Append benchmark commands
if benchmark_config.get('enable_benchmark', True):
env_var = config['benchmark'].get('env_var', {})
@ -378,16 +434,21 @@ def submit_job(config, log_dir, dry_run):
f"--export \"{convert_envs_to_str(env_var)}\""
]
if benchmark_config['use_nv_sa_benchmark']:
if benchmark_config['mode'] == "gen_only":
print(
f"[ERROR] SA benchmark client script is not supported for gen_only mode"
)
sys.exit(1)
benchmark_cmd = [
f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh",
f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
f"bash {os.path.join(script_dir, 'run_benchmark_nv_sa.sh')}",
f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}",
f"&> {log_dir}/6_bench.log"
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
else:
benchmark_cmd = [
f"bash {env_config['work_dir']}/run_benchmark.sh",
f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
f"bash {os.path.join(script_dir, 'run_benchmark.sh')}",
f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}",
f"&> {log_dir}/6_bench.log"
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
@ -426,6 +487,18 @@ def submit_job(config, log_dir, dry_run):
with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
f.write("\n".join(client_cmds) + "\n")
# Resolve slurm script_file path
# If it's a relative path, make it relative to script_dir
slurm_script_file = slurm_config['script_file']
if not os.path.isabs(slurm_script_file):
slurm_script_file = os.path.join(script_dir, slurm_script_file)
# Verify the script file exists
if not os.path.exists(slurm_script_file):
print(f"[ERROR] SLURM script file not found: {slurm_script_file}",
file=sys.stderr)
sys.exit(1)
# Prepare sbatch command
# yapf: disable
cmd = [
@ -437,21 +510,22 @@ def submit_job(config, log_dir, dry_run):
f'--nodes={total_nodes}',
f'--ntasks={total_tasks}',
f'--ntasks-per-node={hw_config["gpus_per_node"]}',
*([] if not slurm_config['set_segment'] else [f'--segment={total_nodes}']),
*([] if not slurm_config['set_segment']
else [f'--segment={total_nodes}']),
f'--output={log_dir}/slurm-%j.out',
f'--error={log_dir}/slurm-%j.err',
*([arg for arg in slurm_config['extra_args'].split() if arg]),
slurm_config['script_file'],
slurm_script_file,
# Benchmark Configuration
'--benchmark-mode', benchmark_config['mode'],
# Environment and paths
'--trtllm-repo', env_config['trtllm_repo'],
'--work-dir', env_config['work_dir'],
'--work-dir', script_dir,
'--full-logdir', log_dir,
'--container-name', container_name,
'--container-mount', env_config['container_mount'],
'--container-mount', container_mount_str,
'--container-image', env_config['container_image'],
'--build-wheel', str(env_config['build_wheel']).lower(),
'--cuda-architectures', env_config['cuda_architectures'],