mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
'''
This script starts the MPICommSession on rank 0 and waits for the
MPI Proxy process to connect and get the MPI task to run.
'''
|
|
from tensorrt_llm._utils import mpi_world_size
|
|
from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
|
|
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
|
|
from tensorrt_llm.llmapi.utils import print_colored_debug
|
|
|
|
|
|
def launch_server_main(sub_comm=None):
    '''Create a RemoteMpiCommSessionServer and run its ``serve()`` call.

    Args:
        sub_comm: Optional communicator object. When provided, its
            ``Get_size()`` result is used as the worker count and the object
            is handed to the server as ``comm``. When ``None``, the worker
            count is taken from ``mpi_world_size()`` and ``comm`` is ``None``.
    '''
    if sub_comm is None:
        worker_count = mpi_world_size()
    else:
        worker_count = sub_comm.Get_size()

    print_colored_debug(
        f"Starting MPI Comm Server with {worker_count} workers\n", "yellow")

    # The listening address is whatever
    # get_spawn_proxy_process_ipc_addr_env() returns at this point.
    comm_server = RemoteMpiCommSessionServer(
        comm=sub_comm,
        n_workers=worker_count,
        addr=get_spawn_proxy_process_ipc_addr_env(),
        is_comm=True,
    )

    started_message = (
        f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}")
    print_colored_debug(started_message)

    comm_server.serve()

    # Reached only once serve() has returned.
    print_colored_debug("MPI Comm Server stopped")
|
|
|
|
|
|
# Script entry point: when executed directly (not imported), start the server
# with the default arguments, i.e. sub_comm=None.
if __name__ == '__main__':
    launch_server_main()
|