mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
'''
|
|
This script is used to start the MPICommSession in the rank0 and wait for the
|
|
MPI Proxy process to connect and get the MPI task to run.
|
|
'''
|
|
from typing import Literal
|
|
|
|
import click
|
|
import zmq
|
|
|
|
from tensorrt_llm._utils import global_mpi_rank, mpi_world_size
|
|
from tensorrt_llm.executor.ipc import ZeroMqQueue
|
|
from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
|
|
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
|
|
from tensorrt_llm.llmapi.utils import logger_debug
|
|
|
|
|
|
def launch_server_main(sub_comm=None):
|
|
num_ranks = sub_comm.Get_size() if sub_comm is not None else mpi_world_size(
|
|
)
|
|
logger_debug(f"Starting MPI Comm Server with {num_ranks} workers\n",
|
|
"yellow")
|
|
server = RemoteMpiCommSessionServer(
|
|
comm=sub_comm,
|
|
n_workers=num_ranks,
|
|
addr=get_spawn_proxy_process_ipc_addr_env(),
|
|
is_comm=True)
|
|
logger_debug(
|
|
f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}")
|
|
server.serve()
|
|
|
|
logger_debug("RemoteMpiCommSessionServer stopped\n", "yellow")
|
|
|
|
|
|
def stop_server_main():
|
|
queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), None),
|
|
use_hmac_encryption=False,
|
|
is_server=False,
|
|
socket_type=zmq.PAIR)
|
|
|
|
try:
|
|
logger_debug(
|
|
f"RemoteMpiCommSessionClient [rank{global_mpi_rank()}] send shutdown signal to server\n",
|
|
"green")
|
|
queue.put(None) # ask RemoteMpiCommSessionServer to shutdown
|
|
except zmq.error.ZMQError as e:
|
|
logger_debug(f"Error during RemoteMpiCommSessionClient shutdown: {e}\n",
|
|
"red")
|
|
|
|
|
|
@click.command()
|
|
@click.option("--action", type=click.Choice(["start", "stop"]), default="start")
|
|
def main(action: Literal["start", "stop"] = "start"):
|
|
'''
|
|
Arguments:
|
|
action: The action to perform.
|
|
start: Start the MPI Comm Server.
|
|
stop: Stop the MPI Comm Server.
|
|
'''
|
|
if action == "start":
|
|
launch_server_main()
|
|
elif action == "stop":
|
|
stop_server_main()
|
|
else:
|
|
raise ValueError(f"Invalid action: {action}")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|