TensorRT-LLMs/tensorrt_llm/llmapi/mgmn_leader_node.py
Kaiyu Xie 2631f21089
Update (#2978)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2025-03-23 16:39:35 +08:00

30 lines
1.0 KiB
Python

'''
This script starts the MPICommSession on rank 0 and waits for the MPI proxy
process to connect and get the MPI task to run.
'''
from tensorrt_llm._utils import mpi_world_size
from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
from tensorrt_llm.llmapi.utils import print_colored_debug
def launch_server_main(sub_comm=None):
    """Create a RemoteMpiCommSessionServer and run its ``serve()`` call.

    Debug messages are printed before the server is created, after it is
    created, and after ``serve()`` returns.

    Args:
        sub_comm: Optional communicator. When provided, its ``Get_size()``
            gives the worker count and it is handed to the server as
            ``comm``. When ``None``, the worker count is taken from
            ``mpi_world_size()`` and ``comm=None`` is passed through.
    """
    if sub_comm is None:
        n_workers = mpi_world_size()
    else:
        n_workers = sub_comm.Get_size()

    print_colored_debug(f"Starting MPI Comm Server with {n_workers} workers\n",
                        "yellow")

    # The address the server is given comes from the spawn-proxy IPC helper.
    ipc_addr = get_spawn_proxy_process_ipc_addr_env()
    comm_server = RemoteMpiCommSessionServer(comm=sub_comm,
                                             n_workers=n_workers,
                                             addr=ipc_addr,
                                             is_comm=True)

    # The helper is queried a second time for the log message.
    print_colored_debug(
        f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}")

    comm_server.serve()
    print_colored_debug("MPI Comm Server stopped")
# Script entry point: launch the server without a sub-communicator, so the
# worker count falls back to mpi_world_size().
if __name__ == "__main__":
    launch_server_main()