TensorRT-LLMs/tensorrt_llm/llmapi/mgmn_leader_node.py
Yan Chunwei f81caf5491
[None][chore] replace print_colored_debug with logger_debug (#8417)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
2025-10-22 17:54:38 +08:00

69 lines
2.2 KiB
Python

'''
This script is used to start the MPICommSession on rank 0 and wait for the
MPI Proxy process to connect and get the MPI task to run.
'''
from typing import Literal
import click
import zmq
from tensorrt_llm._utils import global_mpi_rank, mpi_world_size
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
from tensorrt_llm.llmapi.utils import logger_debug
def launch_server_main(sub_comm=None):
    '''
    Create a RemoteMpiCommSessionServer and run its serve() call.

    Arguments:
        sub_comm: Optional communicator object. When provided, its
            Get_size() determines the number of workers and the object is
            handed to the server as `comm`. When None, mpi_world_size() is
            used for the worker count and `comm=None` is passed.
    '''
    # Pick the worker count from the sub-communicator if there is one,
    # otherwise fall back to the global MPI world size.
    if sub_comm is None:
        worker_count = mpi_world_size()
    else:
        worker_count = sub_comm.Get_size()

    logger_debug(f"Starting MPI Comm Server with {worker_count} workers\n",
                 "yellow")

    # The listening address comes from the spawn-proxy IPC address helper.
    session_server = RemoteMpiCommSessionServer(
        comm=sub_comm,
        n_workers=worker_count,
        addr=get_spawn_proxy_process_ipc_addr_env(),
        is_comm=True,
    )

    started_msg = (
        f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}")
    logger_debug(started_msg)

    # The "stopped" message is logged only after serve() returns.
    session_server.serve()
    logger_debug("RemoteMpiCommSessionServer stopped\n", "yellow")
def stop_server_main():
    '''
    Ask the RemoteMpiCommSessionServer to shut down.

    Opens a client-side ZeroMqQueue (zmq.PAIR socket, HMAC encryption
    disabled) on the address returned by
    get_spawn_proxy_process_ipc_addr_env() and puts None on it. A
    zmq.error.ZMQError raised while sending is logged and not re-raised.
    '''
    # ZeroMqQueue takes an (address, key) pair; no key is supplied here.
    endpoint = (get_spawn_proxy_process_ipc_addr_env(), None)
    client_queue = ZeroMqQueue(
        endpoint,
        use_hmac_encryption=False,
        is_server=False,
        socket_type=zmq.PAIR,
    )

    try:
        shutdown_notice = (
            f"RemoteMpiCommSessionClient [rank{global_mpi_rank()}] send shutdown signal to server\n"
        )
        logger_debug(shutdown_notice, "green")
        # None is the message that asks RemoteMpiCommSessionServer to shutdown.
        client_queue.put(None)
    except zmq.error.ZMQError as err:
        logger_debug(
            f"Error during RemoteMpiCommSessionClient shutdown: {err}\n",
            "red")
@click.command()
@click.option("--action", type=click.Choice(["start", "stop"]), default="start")
def main(action: Literal["start", "stop"] = "start"):
    '''
    Arguments:
        action: The action to perform.
            start: Start the MPI Comm Server.
            stop: Stop the MPI Comm Server.
    '''
    # NOTE: click uses the docstring above as the command's --help text, so
    # its wording is intentionally left as-is; explanations live in comments.

    # Guard-clause dispatch: handle each supported action and return early.
    if action == "start":
        launch_server_main()
        return

    if action == "stop":
        stop_server_main()
        return

    # Reached only when `action` is neither "start" nor "stop".
    raise ValueError(f"Invalid action: {action}")
# Script entry point: invoke the click command only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()