TensorRT-LLMs/tensorrt_llm/llmapi/mgmn_leader_node.py
Yan Chunwei ad4226d946
fix: trtllm-bench build trt engine on slurm (#3825)
* add submit_sync to RemoteMpiSessionClient
* add barrier
* fix comment
* disable test
* fix

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
2025-04-27 22:26:23 +08:00


'''
This script starts an MPICommSession on rank 0 and waits for the MPI proxy
process to connect and submit the MPI tasks to run.
'''
from tensorrt_llm._utils import mpi_world_size
from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
from tensorrt_llm.llmapi.utils import print_colored_debug
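

# Leader entry point: rank 0 hosts the RemoteMpiCommSessionServer that the
# spawned proxy process connects to over IPC.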
def launch_server_main(sub_comm=None):
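    # Size the worker pool from the sub-communicator when one is given;
    # otherwise fall back to the full MPI world size.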
    num_ranks = sub_comm.Get_size() if sub_comm is not None else mpi_world_size()
print_colored_debug(f"Starting MPI Comm Server with {num_ranks} workers\n",
"yellow")
server = RemoteMpiCommSessionServer(
comm=sub_comm,
n_workers=num_ranks,
addr=get_spawn_proxy_process_ipc_addr_env(),
is_comm=True)
    print_colored_debug(
        f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}\n",
        "yellow")
server.serve()
print_colored_debug("RemoteMpiCommSessionServer stopped\n", "yellow")


if __name__ == '__main__':
launch_server_main()
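
# A minimal client-side sketch (assumptions, not shown in this file: the
# proxy-side class is RemoteMpiSessionClient from tensorrt_llm.llmapi.mpi_session,
# its constructor takes the server's IPC address, and submit_sync, added in
# this commit, submits a task and blocks until the workers return results):
#
#     from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
#     from tensorrt_llm.llmapi.mpi_session import RemoteMpiSessionClient
#
#     def task(x):
#         return x * x
#
#     client = RemoteMpiSessionClient(get_spawn_proxy_process_ipc_addr_env())
#     results = client.submit_sync(task, 7)  # e.g. one result per worker rank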