[None][feat] Update disagg slurm scripts (#10712)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
parent 32ab809f36
commit cc0bbde745
@@ -37,7 +37,7 @@ environment:
   cuda_architectures: "" # Optional CUDA architectures to build for (e.g. "90-real;100-real"). If empty, builds for all architectures
   trtllm_wheel_path: "" # Path to pre-built TensorRT-LLM wheel. If provided, install from this wheel instead
   work_dir: "<full_path_to_work_dir>"
-  worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes"
+  worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes NCCL_GRAPH_MIXING_SUPPORT=0"
   server_env_var: "TRTLLM_SERVER_DISABLE_GC=1"

   # Profiling Configuration
@@ -124,9 +124,16 @@ elif [ -d "${trtllm_repo}" ]; then
     echo "TensorRT-LLM installation completed successfully"
 else
     echo "trtllm_wheel_path and trtllm_repo are not provided, will use the installed TensorRT-LLM from the container"
     if [ -v TRT_LLM_GIT_COMMIT ]; then
         echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
+        # get_env file is in the same directory as this script
+        get_env_file=${work_dir}/get_env.py
+        if ! srun --container-name=${container_name} \
+                --container-mounts=${container_mount} --no-container-mount-home \
+                --mpi=pmix --overlap -N 1 --ntasks-per-node=1 \
+                bash -c "python ${get_env_file} -e ${full_logdir}/env_vars.json" \
+                &> ${full_logdir}/2_get_env.log; then
+            cleanup_on_failure "Failed to get TensorRT-LLM environment variables. Check ${full_logdir}/2_get_env.log for details"
+        fi
+        echo "TensorRT-LLM environment variables saved to ${full_logdir}/env_vars.json"
     fi

 # Get node lists and replace the placeholder with the actual node names
examples/disaggregated/slurm/benchmark/get_env.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import os
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Get TensorRT-LLM environment variables and save to JSON"
+    )
+    parser.add_argument("-e", "--env-file", required=True, help="Environment file path")
+    args = parser.parse_args()
+
+    # read env file, append new envs to it
+    with open(args.env_file, "r") as f:
+        env_data = json.load(f)
+
+    # Get environment variables
+    new_env_data = {
+        "TRT_LLM_GIT_COMMIT": os.environ.get("TRT_LLM_GIT_COMMIT", ""),
+        "TRT_LLM_VERSION": os.environ.get("TRT_LLM_VERSION", ""),
+    }
+    print(f"Environment variables: {new_env_data}")
+    env_data.update(new_env_data)
+    # Save to environment file
+    with open(args.env_file, "w") as f:
+        json.dump(env_data, f, indent=2)
+
+    print(f"Environment variables saved to {args.env_file}")
+
+
+if __name__ == "__main__":
+    main()
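Note: get_env.py edits an existing JSON file in place rather than creating one, so submit.py must have written env_vars.json before the srun step above invokes it. A minimal round trip, with hypothetical paths and values:

    # Seed the file the way submit.py would, then let get_env.py append the
    # container's TRT_LLM_GIT_COMMIT / TRT_LLM_VERSION entries to it.
    echo '{"server_env_var": {}}' > /tmp/env_vars.json
    TRT_LLM_GIT_COMMIT=abc1234 TRT_LLM_VERSION=1.0.0 \
        python examples/disaggregated/slurm/benchmark/get_env.py -e /tmp/env_vars.json
    cat /tmp/env_vars.json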
@@ -6,9 +6,9 @@ set -u
 trap 'echo "Error occurred at line $LINENO"; exit 1' ERR

 # Add parameter validation
-if [ "$#" -lt 9 ]; then
-    echo "Error: Missing required arguments"
-    echo "Usage: $0 model_name dataset_file multi_round concurrency_list streaming log_path hostname port"
+if [ "$#" -lt 10 ]; then
+    echo "Error: Missing required arguments, got $# arguments, args: $@"
+    echo "Usage: $0 model_name dataset_file multi_round num_gen_servers concurrency_list streaming log_path hostname port ucx_warmup_requests"
     exit 1
 fi
@@ -21,6 +21,7 @@ streaming=$6
 log_path=$7
 hostname=$8
 port=$9
+ucx_warmup_requests=${10}

 # check process id is not 0
 if [[ ${SLURM_PROCID} != "0" ]]; then
@@ -28,6 +29,99 @@ if [[ ${SLURM_PROCID} != "0" ]]; then
     exit 0
 fi

+do_get_logs(){
+    local input_file=$1
+    local output_file=$2
+    local mode=$3
+    local start_line=$4
+    # check mode is ctx or gen
+    if [ "${mode}" = "ctx" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_generation_tokens': 0" > ${output_file} || true
+    elif [ "${mode}" = "gen" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" > ${output_file} || true
+    else
+        echo "Invalid mode: ${mode}"
+        return 1
+    fi
+    return 0
+}
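Note: the two grep patterns above split mixed iteration-stats output by role: context-phase iterations report 'num_generation_tokens': 0, while generation-phase iterations report zero context requests and tokens. A self-contained sketch with fabricated log lines (the dicts are trimmed; real stats records carry more fields):

    cat > /tmp/worker.log <<'EOF'
    {'num_ctx_requests': 2, 'num_ctx_tokens': 1024, 'num_generation_tokens': 0}
    {'num_ctx_requests': 0, 'num_ctx_tokens': 0, 'num_generation_tokens': 256}
    EOF
    grep -a "'num_generation_tokens': 0" /tmp/worker.log                  # keeps the ctx line
    grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" /tmp/worker.log  # keeps the gen line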
+
+do_process_all_logs(){
+    local input_folder=$1
+    local output_folder=$2
+    local mode=$3
+    if [ "${mode}" != "line" ] && [ "${mode}" != "log" ] && [ "${mode}" != "clean" ]; then
+        echo "Invalid mode: ${mode}"
+        exit 1
+    fi
+    local ctx_log
+    local ctx_num
+    local gen_log
+    local gen_num
+    local line_count
+    local start_line
+    for ctx_log in ${input_folder}/3_output_CTX_*.log; do
+        if [ -f "${ctx_log}" ]; then
+            ctx_num=$(basename "${ctx_log}" | sed 's/3_output_CTX_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${ctx_log})
+                echo ${line_count} > ${output_folder}/ctx_only_line_${ctx_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/ctx_only_line_${ctx_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/ctx_only_line_${ctx_num}.txt)
+                    rm -f ${output_folder}/ctx_only_line_${ctx_num}.txt
+                fi
+                do_get_logs ${ctx_log} ${output_folder}/ctx_only_${ctx_num}.txt "ctx" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${ctx_log}
+            fi
+        fi
+    done
+    # process all the gen log files in the input folder
+    for gen_log in ${input_folder}/3_output_GEN_*.log; do
+        if [ -f "${gen_log}" ]; then
+            gen_num=$(basename "${gen_log}" | sed 's/3_output_GEN_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${gen_log})
+                echo ${line_count} > ${output_folder}/gen_only_line_${gen_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/gen_only_line_${gen_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/gen_only_line_${gen_num}.txt)
+                    rm -f ${output_folder}/gen_only_line_${gen_num}.txt
+                fi
+                do_get_logs ${gen_log} ${output_folder}/gen_only_${gen_num}.txt "gen" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${gen_log}
+            fi
+        fi
+    done
+}
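Note: the "line" and "log" modes bracket one benchmark run: "line" snapshots each worker log's current length, and "log" later extracts only the lines appended after that mark, so every concurrency folder ends up with just its own iteration stats. A hedged usage sketch with an illustrative path:

    log_path=/tmp/bench_logs                                             # hypothetical
    mkdir -p ${log_path}/concurrency_32
    do_process_all_logs ${log_path}/ ${log_path}/concurrency_32 "line"   # record line counts
    # ... run the benchmark at concurrency 32; workers keep appending ...
    do_process_all_logs ${log_path}/ ${log_path}/concurrency_32 "log"    # slice the new lines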
+
+mkdir -p ${log_path}/start_logs
+cp ${log_path}/3_output_CTX_*.log ${log_path}/start_logs/ 2>/dev/null || true
+cp ${log_path}/3_output_GEN_*.log ${log_path}/start_logs/ 2>/dev/null || true
+
+# warmup requests for ucx connections
+if [ "${ucx_warmup_requests}" -gt 0 ]; then
+    echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
+    python -m tensorrt_llm.serve.scripts.benchmark_serving \
+        --model ${model_name} \
+        --dataset-name random \
+        --random-ids \
+        --random-input-len 100 \
+        --random-output-len 10 \
+        --num-prompts ${ucx_warmup_requests} \
+        --host ${hostname} \
+        --port ${port} \
+        --ignore-eos \
+        --non-streaming
+    echo "UCX warmup done"
+fi
+
 echo "Hostname: ${hostname}, Port: ${port}"
 echo "Starting benchmark..."
 for concurrency in ${concurrency_list}; do
@@ -35,6 +129,7 @@ for concurrency in ${concurrency_list}; do
     num_prompts=$((concurrency * multi_round))
     echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
     mkdir -p ${log_path}/concurrency_${concurrency}
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"
     python -m tensorrt_llm.serve.scripts.benchmark_serving \
         --model ${model_name} \
         --backend openai \
@@ -53,4 +148,6 @@ for concurrency in ${concurrency_list}; do
         --percentile-metrics "ttft,tpot,itl,e2el" \
         $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi)
     echo "Benchmark with concurrency ${concurrency} done"
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "log"
 done
+# do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "clean"
@@ -42,6 +42,7 @@ streaming=$8
 log_path=$9
 hostname=${10}
 port=${11}
+ucx_warmup_requests=${12}

 # check process id is not 0
 if [[ ${SLURM_PROCID} != "0" ]]; then
@@ -59,14 +60,107 @@ fi
 echo "Cloning benchmark repository..."
 git clone "${BENCH_SERVING_REPO}" "${BENCH_SERVING_DIR}"

+do_get_logs(){
+    local input_file=$1
+    local output_file=$2
+    local mode=$3
+    local start_line=$4
+    # check mode is ctx or gen
+    if [ "${mode}" = "ctx" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_generation_tokens': 0" > ${output_file} || true
+    elif [ "${mode}" = "gen" ]; then
+        sed -n "${start_line},\$p" ${input_file} | grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" > ${output_file} || true
+    else
+        echo "Invalid mode: ${mode}"
+        return 1
+    fi
+    return 0
+}
+
+do_process_all_logs(){
+    local input_folder=$1
+    local output_folder=$2
+    local mode=$3
+    if [ "${mode}" != "line" ] && [ "${mode}" != "log" ] && [ "${mode}" != "clean" ]; then
+        echo "Invalid mode: ${mode}"
+        exit 1
+    fi
+    local ctx_log
+    local ctx_num
+    local gen_log
+    local gen_num
+    local line_count
+    local start_line
+    for ctx_log in ${input_folder}/output_ctx_*.log; do
+        if [ -f "${ctx_log}" ]; then
+            ctx_num=$(basename "${ctx_log}" | sed 's/output_ctx_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${ctx_log})
+                echo ${line_count} > ${output_folder}/ctx_only_line_${ctx_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/ctx_only_line_${ctx_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/ctx_only_line_${ctx_num}.txt)
+                    rm -f ${output_folder}/ctx_only_line_${ctx_num}.txt
+                fi
+                do_get_logs ${ctx_log} ${output_folder}/ctx_only_${ctx_num}.txt "ctx" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${ctx_log}
+            fi
+        fi
+    done
+    # process all the gen log files in the input folder
+    for gen_log in ${input_folder}/output_gen_*.log; do
+        if [ -f "${gen_log}" ]; then
+            gen_num=$(basename "${gen_log}" | sed 's/output_gen_\([0-9]*\)\.log/\1/')
+            if [ "${mode}" = "line" ]; then
+                line_count=$(wc -l < ${gen_log})
+                echo ${line_count} > ${output_folder}/gen_only_line_${gen_num}.txt
+            elif [ "${mode}" = "log" ]; then
+                if [ ! -f "${output_folder}/gen_only_line_${gen_num}.txt" ]; then
+                    start_line=0
+                else
+                    start_line=$(cat ${output_folder}/gen_only_line_${gen_num}.txt)
+                    rm -f ${output_folder}/gen_only_line_${gen_num}.txt
+                fi
+                do_get_logs ${gen_log} ${output_folder}/gen_only_${gen_num}.txt "gen" ${start_line}
+            elif [ "${mode}" = "clean" ]; then
+                rm -f ${gen_log}
+            fi
+        fi
+    done
+}
+
+mkdir -p ${log_path}/start_logs
+cp ${log_path}/output_ctx_*.log ${log_path}/start_logs/ 2>/dev/null || true
+cp ${log_path}/output_gen_*.log ${log_path}/start_logs/ 2>/dev/null || true
+
+# warmup requests for ucx connections
+if [ "${ucx_warmup_requests}" -gt 0 ]; then
+    echo "warming up ucx connections with small requests... ${ucx_warmup_requests}"
+    python -m tensorrt_llm.serve.scripts.benchmark_serving \
+        --model ${model_name} \
+        --dataset-name random \
+        --random-ids \
+        --random-input-len 100 \
+        --random-output-len 10 \
+        --num-prompts ${ucx_warmup_requests} \
+        --host ${hostname} \
+        --port ${port} \
+        --ignore-eos \
+        --non-streaming
+    echo "UCX warmup done"
+fi
+
 # Run benchmarks
 echo "Starting benchmark..."
 for concurrency in ${concurrency_list}; do
     concurrency=$((concurrency * num_gen_servers))
     num_prompts=$((concurrency * multi_round))
     output_dir="${log_path}/concurrency_${concurrency}"

     echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"
     mkdir -p "${output_dir}"

     python "${BENCH_SCRIPT}" \
@@ -89,4 +183,6 @@ for concurrency in ${concurrency_list}; do
         $([ "${streaming}" = "false" ] && echo "--non-streaming")

     echo "Benchmark with concurrency ${concurrency} done"
+    do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "log"
 done
+# do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "clean"
@@ -34,7 +34,6 @@ else
 fi

 if [ "${benchmark_mode}" = "gen_only" ]; then
-    export TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1
     export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
 fi
@@ -168,6 +168,27 @@ def replace_env_in_file(log_dir, file_path, env_var):
     return tmp_dir


+def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var,
+                  gen_worker_env_var):
+
+    def get_env_var_str(env_var_str):
+        env_data = {}
+        for env_var in env_var_str.split():
+            if '=' in env_var:
+                key, value = env_var.split('=', 1)
+                env_data[key] = value
+        return env_data
+
+    env_data = {}
+    env_data['server_env_var'] = get_env_var_str(server_env_var)
+    env_data['worker_env_var'] = get_env_var_str(worker_env_var)
+    env_data['ctx_worker_env_var'] = get_env_var_str(ctx_worker_env_var)
+    env_data['gen_worker_env_var'] = get_env_var_str(gen_worker_env_var)
+    with open(env_file, 'w') as f:
+        json.dump(env_data, f, indent=2)
+    print(f"Environment variables saved to {env_file}")
+
+
 def submit_job(config, log_dir, dry_run):
     # Extract configurations
     slurm_config = config['slurm']
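Note: get_env_var_str splits the string on whitespace and each KEY=VALUE token on its first '=' only, so values may themselves contain '='. A shell rendering of the same split semantics (the sample string is hypothetical):

    env_str="TLLM_LOG_LEVEL=INFO TRTLLM_ENABLE_PDL=1 OPTS=a=b"
    for tok in ${env_str}; do
        case "${tok}" in
            *=*) printf '%s -> %s\n' "${tok%%=*}" "${tok#*=}" ;;   # first '=' delimits
        esac
    done
    # prints TLLM_LOG_LEVEL -> INFO, TRTLLM_ENABLE_PDL -> 1, OPTS -> a=b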
@@ -179,6 +200,11 @@ def submit_job(config, log_dir, dry_run):
     worker_config = config['worker_config']
     benchmark_config = config['benchmark']

+    if 'work_dir' in env_config and os.path.isdir(env_config['work_dir']):
+        script_dir = env_config['work_dir']
+    else:
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+
     # Set default accuracy configuration for backward compatibility
     if 'accuracy' not in config:
         config['accuracy'] = {
@@ -200,11 +226,15 @@ def submit_job(config, log_dir, dry_run):
     env_config.setdefault('worker_env_var', '')
     env_config.setdefault('server_env_var', '')

-    worker_env_var = env_config.get('worker_env_var')
-    server_env_var = env_config.get('server_env_var')
+    worker_env_var = env_config.get('worker_env_var', '')
+    ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
+    gen_worker_env_var = env_config.get('gen_worker_env_var', '')
+    server_env_var = env_config.get('server_env_var', '')
     if benchmark_config['mode'] == "gen_only_no_context":
         worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
         server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
+    if benchmark_config['mode'] == "gen_only":
+        worker_env_var += " TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1"

     profiling_config = config.get('profiling', {})
     profiling_config.setdefault('nsys_on', False)
@@ -228,6 +258,8 @@ def submit_job(config, log_dir, dry_run):
     gen_pp_size = worker_config['gen'].get('pipeline_parallel_size', 1)
     gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
     gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)
+    ucx_warmup_requests = 2 * ctx_world_size * \
+        gen_world_size if benchmark_config['mode'] == "e2e" else 0

     total_nodes = ctx_nodes + gen_nodes
     total_tasks = total_nodes * gpus_per_node
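Note: the count scales with both world sizes, presumably so the fanned-out warmup traffic touches every ctx-rank/gen-rank pair before timing starts; only e2e mode warms up. Worked example with hypothetical sizes:

    # ctx world size 4 (tp*cp*pp), gen world size 8: e2e mode sends 2*4*8 = 64
    # small requests; every other benchmark mode sends none.
    ctx_world_size=4; gen_world_size=8; mode="e2e"
    if [ "${mode}" = "e2e" ]; then
        ucx_warmup_requests=$((2 * ctx_world_size * gen_world_size))
    else
        ucx_warmup_requests=0
    fi
    echo "${ucx_warmup_requests}"   # 64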
@@ -251,8 +283,10 @@ def submit_job(config, log_dir, dry_run):
                                  {}).get('num_nextn_predict_layers', 0)

     # Create base log directory path
+    if 'log_dir' in env_config and env_config['log_dir']:
+        log_dir = env_config['log_dir']
     if log_dir is None:
-        log_base = os.path.join(env_config['work_dir'], "logs")
+        log_base = os.path.join(script_dir, "logs")

         date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S")
         log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}")
@@ -266,13 +300,30 @@ def submit_job(config, log_dir, dry_run):
     # Create full log directory path
     log_dir = os.path.join(log_base, dir_suffix)

-    # Remove existing directory if it exists
+    # if trtllm_config.yaml exists, don't remove the directory, remove other files in the directory except trtllm_config.yaml
+    # also don't remove concurrency_* folders
     if os.path.exists(log_dir):
-        print(f"[WARNING] Removing existing log directory: {log_dir}")
-        shutil.rmtree(log_dir)
-    os.makedirs(log_dir)
+        if not os.path.exists(os.path.join(log_dir, 'trtllm_config.yaml')):
+            print(f"[WARNING] Removing existing log directory: {log_dir}")
+            shutil.rmtree(log_dir)
+        else:
+            print(
+                f"[WARNING] trtllm_config.yaml exists, not removing the directory: {log_dir}"
+            )
+            for file in os.listdir(log_dir):
+                if file != 'trtllm_config.yaml' and not file.startswith(
+                        'concurrency_'):
+                    if os.path.isdir(os.path.join(log_dir, file)):
+                        shutil.rmtree(os.path.join(log_dir, file))
+                    else:
+                        os.remove(os.path.join(log_dir, file))
+    os.makedirs(log_dir, exist_ok=True)
     print(f"Log will be saved to: {log_dir}")

+    # Save environment variables
+    save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var,
+                  worker_env_var, ctx_worker_env_var, gen_worker_env_var)
+
     # Setup config file paths and save worker configs
     ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
     gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
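Note: the reuse branch above preserves trtllm_config.yaml and previous concurrency_* results and deletes everything else in the directory. A shell equivalent of that policy (a sketch, not the script's code; the path is hypothetical):

    log_dir=/tmp/demo_logdir
    for f in "${log_dir}"/*; do
        base=$(basename "${f}")
        if [ "${base}" != "trtllm_config.yaml" ] && [[ "${base}" != concurrency_* ]]; then
            rm -rf "${f}"   # everything except the config and past results goes
        fi
    done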
@@ -300,6 +351,8 @@ def submit_job(config, log_dir, dry_run):

     container_name = "disaggr-test"
     start_server_cmds = []
+    container_mount_str = env_config['container_mount']
+    container_mount_str += f",{script_dir}:{script_dir}"
     # Generate start worker commands with placeholder hostnames
     for server_type in allocations.keys():
         for server_id in allocations[server_type].keys():
@@ -307,7 +360,11 @@ def submit_job(config, log_dir, dry_run):
             cuda_devices = ",".join([
                 str(device) for device in list(allocation["nodes"].values())[0]
             ])
-            cur_worker_env_var = worker_env_var + f" CUDA_VISIBLE_DEVICES={cuda_devices}"
+            cur_worker_env_var = worker_env_var + \
+                f" CUDA_VISIBLE_DEVICES={cuda_devices}" + \
+                (f" {ctx_worker_env_var}" if server_type == "CTX" else "") + \
+                (f" {gen_worker_env_var}" if server_type == "GEN" else "")
+            # Use script_dir for start_worker.sh
             cmd = [
                 "srun -l",
                 f"--nodelist {','.join(allocation['nodes'].keys())}",
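Note: a worker's final environment string is now composed as shared worker_env_var, then CUDA_VISIBLE_DEVICES, then the role-specific suffix; if a variable appears twice, the later (role-specific) assignment typically wins when the string is applied. Illustration with hypothetical override values:

    worker_env_var="TLLM_LOG_LEVEL=INFO"
    ctx_worker_env_var="TLLM_LOG_LEVEL=DEBUG"   # hypothetical ctx-only override
    server_type="CTX"
    cur="${worker_env_var} CUDA_VISIBLE_DEVICES=0,1,2,3"
    if [ "${server_type}" = "CTX" ]; then cur="${cur} ${ctx_worker_env_var}"; fi
    echo "${cur}"   # TLLM_LOG_LEVEL=INFO CUDA_VISIBLE_DEVICES=0,1,2,3 TLLM_LOG_LEVEL=DEBUG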
@@ -316,9 +373,9 @@ def submit_job(config, log_dir, dry_run):
                 f"--ntasks-per-node {gpus_per_node}",
                 f"--container-image {env_config['container_image']}",
                 f"--container-name {container_name}",
-                f"--container-mounts {env_config['container_mount']}",
+                f"--container-mounts {container_mount_str}",
                 "--no-container-mount-home --mpi=pmix --overlap",
-                f"bash {os.path.join(env_config['work_dir'], 'start_worker.sh')}",
+                f"bash {os.path.join(script_dir, 'start_worker.sh')}",
                 server_type,
                 str(server_id),
                 env_config['model_path'],
@@ -336,26 +393,26 @@ def submit_job(config, log_dir, dry_run):
             ]
             start_server_cmds.append(" ".join(cmd))

-    # Generate start server commands
+    # Generate start server commands (use script_dir for start_server.sh)
     cmd = [
         "srun -l",
         f"--nodelist {disagg_server_hostname}",
         f"--container-name={container_name}",
         f"--container-image={env_config['container_image']}",
-        f"--container-mounts={env_config['container_mount']}",
+        f"--container-mounts={container_mount_str}",
         f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1",
-        f"bash {env_config['work_dir']}/start_server.sh {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
+        f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
         f"&> {log_dir}/4_output_server.log &",
     ]
     start_server_cmds.append(" ".join(cmd))

-    # Generate wait server command
+    # Generate wait server command (use script_dir for wait_server.sh)
     cmd = [
         "srun -l",
         f"--container-name={container_name}",
-        f"--container-mounts={env_config['container_mount']}",
+        f"--container-mounts={container_mount_str}",
         f"--mpi=pmix --overlap -N 1 -n 1",
-        f"bash {env_config['work_dir']}/wait_server.sh {disagg_server_hostname} {disagg_server_port}",
+        f"bash {os.path.join(script_dir, 'wait_server.sh')} {disagg_server_hostname} {disagg_server_port}",
         f"&> {log_dir}/5_wait_server.log",
     ]
     start_server_cmds.append(" ".join(cmd))
@@ -363,14 +420,13 @@ def submit_job(config, log_dir, dry_run):
     with open(os.path.join(log_dir, "start_server_cmds.sh"), "w") as f:
         f.write("\n".join(start_server_cmds) + "\n")

-    # Generate client commands
+    # Generate client commands (use script_dir for benchmark scripts)
     client_cmds = []
     client_slurm_prefix = [
         f"srun -l --container-name={container_name}",
-        f"--container-mounts={env_config['container_mount']}",
+        f"--container-mounts={container_mount_str}",
         f"--mpi=pmix --overlap -N 1 -n 1",
     ]

     # Append benchmark commands
     if benchmark_config.get('enable_benchmark', True):
         env_var = config['benchmark'].get('env_var', {})
@@ -378,16 +434,21 @@ def submit_job(config, log_dir, dry_run):
             f"--export \"{convert_envs_to_str(env_var)}\""
         ]
         if benchmark_config['use_nv_sa_benchmark']:
+            if benchmark_config['mode'] == "gen_only":
+                print(
+                    f"[ERROR] SA benchmark client script is not supported for gen_only mode"
+                )
+                sys.exit(1)
             benchmark_cmd = [
-                f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh",
-                f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
+                f"bash {os.path.join(script_dir, 'run_benchmark_nv_sa.sh')}",
+                f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}",
                 f"&> {log_dir}/6_bench.log"
             ]
             client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
         else:
             benchmark_cmd = [
-                f"bash {env_config['work_dir']}/run_benchmark.sh",
-                f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
+                f"bash {os.path.join(script_dir, 'run_benchmark.sh')}",
+                f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port} {ucx_warmup_requests}",
                 f"&> {log_dir}/6_bench.log"
             ]
             client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
@@ -426,6 +487,18 @@ def submit_job(config, log_dir, dry_run):
     with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
         f.write("\n".join(client_cmds) + "\n")

+    # Resolve slurm script_file path
+    # If it's a relative path, make it relative to script_dir
+    slurm_script_file = slurm_config['script_file']
+    if not os.path.isabs(slurm_script_file):
+        slurm_script_file = os.path.join(script_dir, slurm_script_file)
+
+    # Verify the script file exists
+    if not os.path.exists(slurm_script_file):
+        print(f"[ERROR] SLURM script file not found: {slurm_script_file}",
+              file=sys.stderr)
+        sys.exit(1)
+
     # Prepare sbatch command
     # yapf: disable
     cmd = [
@@ -437,21 +510,22 @@ def submit_job(config, log_dir, dry_run):
         f'--nodes={total_nodes}',
         f'--ntasks={total_tasks}',
         f'--ntasks-per-node={hw_config["gpus_per_node"]}',
-        *([] if not slurm_config['set_segment'] else [f'--segment={total_nodes}']),
+        *([] if not slurm_config['set_segment']
+          else [f'--segment={total_nodes}']),
         f'--output={log_dir}/slurm-%j.out',
         f'--error={log_dir}/slurm-%j.err',
         *([arg for arg in slurm_config['extra_args'].split() if arg]),
-        slurm_config['script_file'],
+        slurm_script_file,

         # Benchmark Configuration
         '--benchmark-mode', benchmark_config['mode'],

         # Environment and paths
         '--trtllm-repo', env_config['trtllm_repo'],
-        '--work-dir', env_config['work_dir'],
+        '--work-dir', script_dir,
         '--full-logdir', log_dir,
        '--container-name', container_name,
-        '--container-mount', env_config['container_mount'],
+        '--container-mount', container_mount_str,
         '--container-image', env_config['container_image'],
         '--build-wheel', str(env_config['build_wheel']).lower(),
         '--cuda-architectures', env_config['cuda_architectures'],