mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* add trtllm-bench mgmn test Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
'''
|
|
This script is used to start the MPICommSession in the rank0 and wait for the
|
|
MPI Proxy process to connect and get the MPI task to run.
|
|
'''
|
|
from tensorrt_llm._utils import mpi_world_size
|
|
from tensorrt_llm.executor.utils import (
|
|
get_spawn_proxy_process_ipc_addr_env,
|
|
get_spawn_proxy_process_ipc_hmac_key_env)
|
|
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
|
|
from tensorrt_llm.llmapi.utils import print_colored_debug
|
|
|
|
|
|
def launch_server_main(sub_comm=None):
|
|
num_ranks = sub_comm.Get_size() if sub_comm is not None else mpi_world_size(
|
|
)
|
|
print_colored_debug(f"Starting MPI Comm Server with {num_ranks} workers\n",
|
|
"yellow")
|
|
server = RemoteMpiCommSessionServer(
|
|
comm=sub_comm,
|
|
n_workers=num_ranks,
|
|
addr=get_spawn_proxy_process_ipc_addr_env(),
|
|
hmac_key=get_spawn_proxy_process_ipc_hmac_key_env(),
|
|
is_comm=True)
|
|
print_colored_debug(
|
|
f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}")
|
|
server.serve()
|
|
|
|
print_colored_debug("RemoteMpiCommSessionServer stopped\n", "yellow")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
launch_server_main()
|