mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 15:55:08 +08:00
[None][feat] Support disagg slurm jobs rescheduling (#11218)
This commit is contained in:
parent
f6fff18142
commit
e2bd9cce1e
@ -56,14 +56,16 @@ cleanup_on_failure() {
|
||||
replace_placeholder() {
|
||||
file_path="$1"
|
||||
all_nodes_str="$2"
|
||||
new_file_path="$3"
|
||||
cp "$file_path" "$new_file_path"
|
||||
IFS=',' read -r -a node_array <<< "$all_nodes_str"
|
||||
for i in "${!node_array[@]}"; do
|
||||
current_val="${node_array[$i]}"
|
||||
placeholder="<node${i}_placeholder>"
|
||||
|
||||
# Use sed to replace the placeholder with the value in-place
|
||||
sed -i "s|$placeholder|$current_val|g" "${file_path}"
|
||||
echo "Replaced $placeholder with $current_val in ${file_path}"
|
||||
sed -i "s|$placeholder|$current_val|g" "${new_file_path}"
|
||||
echo "Replaced $placeholder with $current_val in ${new_file_path}"
|
||||
done
|
||||
}
|
||||
|
||||
@ -142,12 +144,15 @@ all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort))
|
||||
all_nodes_str=$(IFS=','; echo "${all_nodes[*]}")
|
||||
echo "all_nodes_str: ${all_nodes_str}"
|
||||
|
||||
start_server_cmds_base_file=${full_logdir}/start_server_cmds_base.sh
|
||||
start_server_cmds_file=${full_logdir}/start_server_cmds.sh
|
||||
replace_placeholder ${start_server_cmds_file} ${all_nodes_str}
|
||||
replace_placeholder "${start_server_cmds_base_file}" "${all_nodes_str}" "${start_server_cmds_file}"
|
||||
server_config_base_file=${full_logdir}/server_config_base.yaml
|
||||
server_config_file=${full_logdir}/server_config.yaml
|
||||
replace_placeholder ${server_config_file} ${all_nodes_str}
|
||||
replace_placeholder "${server_config_base_file}" "${all_nodes_str}" "${server_config_file}"
|
||||
client_cmds_base_file=${full_logdir}/client_cmds_base.sh
|
||||
client_cmds_file=${full_logdir}/client_cmds.sh
|
||||
replace_placeholder ${client_cmds_file} ${all_nodes_str}
|
||||
replace_placeholder "${client_cmds_base_file}" "${all_nodes_str}" "${client_cmds_file}"
|
||||
|
||||
# start the servers (skip ctx workers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set).
|
||||
echo "Starting worker commands from ${start_server_cmds_file}..."
|
||||
@ -164,7 +169,6 @@ done
|
||||
echo "Server is ready!"
|
||||
|
||||
# Start client commands
|
||||
client_cmds_file=${full_logdir}/client_cmds.sh
|
||||
echo "Starting client commands from ${client_cmds_file}..."
|
||||
while read -r cmd <&3; do
|
||||
echo "Starting client command: ${cmd}"
|
||||
|
||||
@ -457,7 +457,7 @@ def submit_job(config, log_dir, dry_run):
|
||||
|
||||
# Generate disagg server config
|
||||
server_config = convert_allocations_to_server_config(allocations)
|
||||
with open(os.path.join(log_dir, "server_config.yaml"), "w") as f:
|
||||
with open(os.path.join(log_dir, "server_config_base.yaml"), "w") as f:
|
||||
yaml.dump(server_config, f)
|
||||
disagg_server_hostname = server_config['hostname']
|
||||
disagg_server_port = server_config['port']
|
||||
@ -557,7 +557,7 @@ def submit_job(config, log_dir, dry_run):
|
||||
]
|
||||
start_server_cmds.append(" ".join(cmd))
|
||||
|
||||
with open(os.path.join(log_dir, "start_server_cmds.sh"), "w") as f:
|
||||
with open(os.path.join(log_dir, "start_server_cmds_base.sh"), "w") as f:
|
||||
f.write("\n".join(start_server_cmds) + "\n")
|
||||
|
||||
# Generate client commands (use script_dir for benchmark scripts)
|
||||
@ -632,7 +632,7 @@ def submit_job(config, log_dir, dry_run):
|
||||
]
|
||||
client_cmds.append(" ".join(done_cmd))
|
||||
|
||||
with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
|
||||
with open(os.path.join(log_dir, "client_cmds_base.sh"), "w") as f:
|
||||
f.write("\n".join(client_cmds) + "\n")
|
||||
|
||||
# Resolve slurm script_file path
|
||||
|
||||
Loading…
Reference in New Issue
Block a user