#!/bin/bash
#
# Usage: <this-script> COMMAND [ARGS...]
#
# Dispatches on the MPI rank found in the environment:
#   * rank 0 runs COMMAND in a background subshell whose MPI-related
#     environment variables have been removed, and runs
#     `python3 -m tensorrt_llm.llmapi.mgmn_leader_node` in the foreground;
#   * every other rank runs `python3 -m tensorrt_llm.llmapi.mgmn_worker_node`.
#
# Environment read:
#   SLURM_PROCID, OMPI_COMM_WORLD_RANK, PMI_RANK, PMI_ID  rank (default 0)
#   SLURM_NTASKS, OMPI_COMM_WORLD_SIZE                    world size (default 1)
#   TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR                     optional user override
#
# Environment exported:
#   TLLM_SPAWN_PROXY_PROCESS, tllm_mpi_size,
#   TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR, TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY
#
# Exit status:
#   rank 0:      the subshell's status if non-zero, else the leader node's.
#   other ranks: the worker node's status.
#
# NOTE: `set -u` is intentionally not enabled; several of the variables above
# are legitimately unset depending on the launcher in use.
set -Eeo pipefail

task_with_command=("$@")
# First launcher-provided rank wins; a plain (non-MPI) run is treated as rank 0.
mpi_rank=${SLURM_PROCID:-${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-${PMI_ID:-0}}}}

# Print all arguments, joined by single spaces, to stderr in yellow.
# "$*" (not "$@") keeps the message one word so the colour codes wrap all of it.
log_stderr() {
  echo -e "\033[33m$*\033[0m" >&2
}

log_stderr "mpi_rank: $mpi_rank"

# Tell TRTLLM to spawn an additional process for the Proxy
export TLLM_SPAWN_PROXY_PROCESS=1

# Print the MPI world size: SLURM_NTASKS, else OMPI_COMM_WORLD_SIZE, else 1.
function mpi_world_size {
  if [[ -n "$SLURM_NTASKS" ]]; then
    echo "$SLURM_NTASKS"
  elif [[ -n "$OMPI_COMM_WORLD_SIZE" ]]; then
    echo "$OMPI_COMM_WORLD_SIZE"
  else
    echo "1"
  fi
}

# Export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR unless the user already set it.
# The generated address is tcp://127.0.0.1:<port>, where <port> is the first
# port >= 10012 that could be bound at the time of the check.
# Returns non-zero (instead of exporting an address without a port) when the
# port probe fails.
function maybe_export_free_tcp_addr_for_spawn_proxy_process {
  # use user specified address if provided
  if [[ -n "$TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR" ]]; then
    log_stderr "Using user-provided TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR: $TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR"
    return
  fi

  # find free port starting from 10012
  # Declaration is kept separate from the assignment so that the exit status
  # of the probe is not masked by `local`. python3 is used because it is the
  # interpreter the rest of this script depends on.
  local free_port
  if ! free_port=$(python3 -c '
import socket

s = socket.socket()
port = 10012
while True:
    try:
        s.bind(("", port))
        break
    except OSError:
        port += 1
print(port)
s.close()
') || [[ -z "$free_port" ]]; then
    log_stderr "Failed to find a free TCP port for TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR"
    return 1
  fi

  export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:${free_port}"
  log_stderr "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR: $TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR"
}

tllm_mpi_size=$(mpi_world_size)
export tllm_mpi_size
log_stderr "tllm_mpi_size: $tllm_mpi_size"

maybe_export_free_tcp_addr_for_spawn_proxy_process

# Generate a 64-hex-character key. Never continue with an empty key: try
# openssl first, then fall back to Python's `secrets` module.
if ! hmac_key=$(openssl rand -hex 32) || [[ -z "$hmac_key" ]]; then
  log_stderr "openssl could not generate the HMAC key, falling back to python3"
  if ! hmac_key=$(python3 -c 'import secrets; print(secrets.token_hex(32))') \
    || [[ -z "$hmac_key" ]]; then
    log_stderr "Failed to generate TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY"
    exit 1
  fi
fi
export TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY="$hmac_key"
unset hmac_key

if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
  log_stderr "Rank${mpi_rank} run ${task_with_command[*]} in background"

  # MPI doesn't allow spawning a process that shares the MPI environment from
  # within an MPI process; a duplicate MPI_Init in the child would cause
  # undefined behavior. The MPI-related variables are therefore removed, but
  # only inside the subshell that runs the task, so this (parent) process
  # keeps its MPI environment for the leader node started below.
  mpi_blacklist=(
    OMPI_ PMIX_ PMI_ SLURM_ MPI_ UCX_ I_MPI_ HYDRA_ KMP_ MPICH_ MV2_ CRAY_
  )

  (
    # Remove MPI-related variables only in the subshell context
    while IFS= read -r var; do
      for prefix in "${mpi_blacklist[@]}"; do
        if [[ "$var" == "$prefix"* ]]; then
          unset "$var"
          break
        fi
      done
    done < <(compgen -e)

    # Turn off "exit on error" so the following lines always run
    set +e

    # Execute the task with cleaned environment
    "${task_with_command[@]}"
    task_exit_code=$?
    log_stderr "Rank${mpi_rank} Task exit code: $task_exit_code"

    # Stop the MPI Comm server
    python3 -m tensorrt_llm.llmapi.mgmn_leader_node --action stop
    mpi_exit_code=$?
    log_stderr "Rank${mpi_rank} MPI Comm server exit code: $mpi_exit_code"

    # Propagate task exit status; the stop status only matters if the task
    # itself succeeded.
    if [ "$task_exit_code" -ne 0 ]; then
      exit "$task_exit_code"
    else
      exit "$mpi_exit_code"
    fi
  ) &

  # Capture subshell PID
  subshell_pid=$!

  # Turn off "exit on error" so the following lines always run
  set +e

  log_stderr "Rank${mpi_rank} Subshell PID: $subshell_pid"

  log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $(mpi_world_size) ..."
  log_stderr "Rank0 host: $HOSTNAME"
  python3 -m tensorrt_llm.llmapi.mgmn_leader_node
  mgmn_leader_node_exit_code=$?
  log_stderr "Rank${mpi_rank} MGMN leader node exit code: $mgmn_leader_node_exit_code"

  # Wait for subshell; `wait <pid>` returns the subshell's exit code
  wait "$subshell_pid"
  subshell_exit_code=$?
  log_stderr "Rank${mpi_rank} Subshell exit code: $subshell_exit_code"

  # Propagate subshell exit status
  if [ "$subshell_exit_code" -ne 0 ]; then
    exit "$subshell_exit_code"
  else
    exit "$mgmn_leader_node_exit_code"
  fi
else
  # Turn off "exit on error" so the following lines always run
  set +e

  log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $(mpi_world_size) ..."
  python3 -m tensorrt_llm.llmapi.mgmn_worker_node
  mgmn_worker_node_exit_code=$?
  log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"

  exit "$mgmn_worker_node_exit_code"
fi