[https://nvbugs/5437384][test] fix trtllm-llmapi-launch multi tests with single launch (#8397)

Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Author: Yan Chunwei, 2025-10-17 12:14:43 +08:00 (committed by GitHub)
parent 82430f84dc
commit 995b93bc38
5 changed files with 232 additions and 6 deletions

--- a/tensorrt_llm/llmapi/mpi_session.py
+++ b/tensorrt_llm/llmapi/mpi_session.py

@@ -48,6 +48,10 @@ class MPINodeState:
     '''
     state = None
 
+    # Global MPICommExecutor instance to be reused across multiple MpiCommSession instances
+    # This is necessary because MPICommExecutor can only be created once per MPI process
+    _global_comm_executor = None
+    _global_mpi_pool = None
 
     @staticmethod
     def is_initialized() -> bool:
@@ -183,6 +187,7 @@ class MpiCommSession(MpiSession):
         self.n_workers = n_workers
         self.thread_pool: Optional[ThreadPoolExecutor] = None
         self.mpi_pool: Optional[MPIPoolExecutor] = None
+        self.owns_mpi_pool = False  # Track if this instance owns the mpi_pool
 
         if n_workers <= 0:
             raise ValueError(
@@ -230,9 +235,11 @@ class MpiCommSession(MpiSession):
         return [future.result() for future in futures]
 
     def shutdown(self, wait=True):
-        if self.mpi_pool is not None:
+        # Only shutdown the mpi_pool if this instance created it
+        # For shared global mpi_pool, we don't shut it down
+        if self.mpi_pool is not None and self.owns_mpi_pool:
             self.mpi_pool.shutdown(wait=wait)
-            self.mpi_pool = None
+        self.mpi_pool = None
 
         if self.thread_pool is not None:
             self.thread_pool.shutdown(wait=wait)
             self.thread_pool = None
@@ -244,8 +251,37 @@ class MpiCommSession(MpiSession):
         assert not self.mpi_pool, 'MPI session already started'
 
         self.thread_pool = ThreadPoolExecutor(max_workers=2)
-        comm_executor = MPICommExecutor(self.comm)
-        self.mpi_pool = comm_executor.__enter__()
+        # Use global MPICommExecutor if using COMM_WORLD
+        # This is necessary because MPICommExecutor can only be created once per MPI process
+        print_colored_debug(
+            f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
+            "grey")
+        if ENABLE_MULTI_DEVICE:
+            print_colored_debug(
+                f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
+                "grey")
+
+        if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
+            if MPINodeState._global_comm_executor is None:
+                print_colored_debug(
+                    "Creating global MPICommExecutor for COMM_WORLD\n",
+                    "yellow")
+                MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
+                MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__()
+            else:
+                print_colored_debug(
+                    "Reusing global MPICommExecutor for COMM_WORLD\n", "yellow")
+            self.mpi_pool = MPINodeState._global_mpi_pool
+            self.owns_mpi_pool = False
+        else:
+            print_colored_debug(
+                f"_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
+                "grey")
+            # For non-COMM_WORLD communicators, create a new executor
+            comm_executor = MPICommExecutor(self.comm)
+            self.mpi_pool = comm_executor.__enter__()
+            self.owns_mpi_pool = True
 
     def __del__(self):
         self.shutdown_abort()
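
Aside: a minimal, self-contained sketch of the once-per-process reuse pattern this hunk implements (the function and variable names below are illustrative, not part of the commit). The MPICommExecutor context manager is entered a single time and the pool it yields is cached at module scope, matching the comment that it can only be created once per MPI process:

    from mpi4py import MPI
    from mpi4py.futures import MPICommExecutor

    _executor_cm = None  # context manager, entered exactly once per process
    _pool = None         # executor yielded by __enter__ (None on non-root ranks)

    def get_comm_world_pool():
        """Return the cached COMM_WORLD pool, creating it on first call."""
        global _executor_cm, _pool
        if _executor_cm is None:
            _executor_cm = MPICommExecutor(MPI.COMM_WORLD)
            _pool = _executor_cm.__enter__()  # later callers reuse this pool
        return _pool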
@@ -264,9 +300,35 @@ class RemoteTask(NamedTuple):
 
 class RemoteMpiCommSessionClient(MpiSession):
     '''
     RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.
+
+    Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
+    one connection at a time. Multiple LLM instances will reuse the same client connection.
     '''
+
+    _global_instance = None
+    _global_instance_lock = threading.Lock()
+
+    def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
+        # Implement singleton pattern to reuse the same client connection
+        # for multiple LLM instances, since PAIR sockets only support one connection
+        with cls._global_instance_lock:
+            if cls._global_instance is None or cls._global_instance.addr != addr:
+                print_colored_debug(
+                    f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+                instance = super().__new__(cls)
+                cls._global_instance = instance
+                instance._initialized = False
+            else:
+                print_colored_debug(
+                    f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+            return cls._global_instance
+
     def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
+        # Only initialize once
+        if self._initialized:
+            return
+
         # FIXME: this is a hack to avoid circular import, resolve later
         from tensorrt_llm.executor.ipc import ZeroMqQueue
         self.addr = addr
@@ -277,6 +339,7 @@ class RemoteMpiCommSessionClient(MpiSession):
             socket_type=zmq.PAIR,
             use_hmac_encryption=bool(hmac_key))
         self._is_shutdown = False
+        self._initialized = True
 
     def submit(self,
                task: Callable[..., T],
@@ -330,10 +393,16 @@ class RemoteMpiCommSessionClient(MpiSession):
         self.shutdown()
 
     def shutdown(self, wait=True):
-        pass
+        # NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
+        # The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
+        # LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
+        # using it. The connection stays open for the entire lifetime of the mgmn setup.
+        print_colored_debug(
+            f"RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
+            "grey")
 
     def shutdown_abort(self, grace: float = 60, reason=None):
-        pass
+        self.shutdown()
 
 
 class RemoteMpiCommSessionServer():
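
Aside: a standalone sketch of the `__new__`/`_initialized` singleton pattern the client uses (the class name and address below are hypothetical). Re-constructing with the same address returns the cached instance, and `__init__` returns early so the PAIR connection is opened only once:

    import threading
    from typing import Optional

    class PairClient:
        """Hypothetical stand-in for the addr-keyed singleton client."""
        _instance: Optional["PairClient"] = None
        _lock = threading.Lock()

        def __new__(cls, addr: str):
            with cls._lock:
                if cls._instance is None or cls._instance.addr != addr:
                    inst = super().__new__(cls)
                    inst._initialized = False  # let __init__ run exactly once
                    cls._instance = inst
                return cls._instance

        def __init__(self, addr: str):
            if self._initialized:
                return  # already connected; do not reopen the PAIR socket
            self.addr = addr  # the real class opens a ZeroMqQueue here
            self._initialized = True

    a = PairClient("ipc:///tmp/proxy.sock")  # hypothetical address
    b = PairClient("ipc:///tmp/proxy.sock")
    assert a is b  # multiple LLM instances share one connection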
@@ -394,7 +463,26 @@ class RemoteMpiCommSessionServer():
     def serve(self):
         print_colored_debug(
             f"RemoteMpiCommSessionServer listening on {self.addr}\n", "yellow")
+        pending_futures = []
         while True:
+            # Wait for any pending futures from previous tasks to complete
+            # This ensures all ranks are ready before accepting the next task
+            if pending_futures:
+                print_colored_debug(
+                    f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
+                    "grey")
+                for future in pending_futures:
+                    try:
+                        future.result()  # Wait for completion
+                    except Exception as e:
+                        print_colored(
+                            f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
+                            "red")
+                pending_futures.clear()
+                print_colored_debug(
+                    "RemoteMpiCommSessionServer all pending futures completed\n",
+                    "grey")
+
             message: Optional[RemoteTask] = self.queue.get()
             if message is None:
                 print_colored_debug(
@@ -411,6 +499,8 @@ class RemoteMpiCommSessionServer():
                 *message.args, **message.kwargs)
             self.num_results = self.session.n_workers
             assert len(futures) == self.num_results == mpi_world_size()
+            # Store futures to wait for them before the next task
+            pending_futures = list(futures)
             if message.sync:
                 for future in futures:
                     future.add_done_callback(self.mpi_future_callback)
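
Aside: the serve-loop change amounts to a drain-before-accept barrier. A compact sketch under assumed names (`task_queue`, `pool`, and `n_workers` stand in for the server's real members):

    from concurrent.futures import Future, ThreadPoolExecutor
    from queue import Queue
    from typing import Callable, List, Optional

    def serve(task_queue: "Queue[Optional[Callable[[], None]]]",
              pool: ThreadPoolExecutor, n_workers: int) -> None:
        pending: List[Future] = []
        while True:
            # Barrier: the previous task must finish on every worker
            # before the next one is accepted.
            for f in pending:
                try:
                    f.result()
                except Exception as e:
                    print(f"worker failed: {e}")
            pending.clear()

            task = task_queue.get()
            if task is None:  # shutdown sentinel, mirroring the real server
                break
            pending = [pool.submit(task) for _ in range(n_workers)]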

--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml

@@ -35,6 +35,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
+  # llmapi
+  - unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
 - condition:
     ranges:
       system_gpu_count:

--- /dev/null
+++ b/tests/unittest/llmapi/_run_multi_llm_tasks.py

@@ -0,0 +1,33 @@
+import os
+import sys
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi.utils import print_colored
+
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.llm_data import llm_models_root
+# isort: on
+
+model_path = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"
+
+
+def run_llm_tp2():
+    with LLM(model=model_path, tensor_parallel_size=2) as llm:
+        sampling_params = SamplingParams(max_tokens=10, end_id=-1)
+        for output in llm.generate(["Hello, my name is"], sampling_params):
+            print(output)
+
+
+def run_multi_llm_tasks():
+    for i in range(3):
+        print_colored(f"Running LLM task {i}\n", "green")
+        run_llm_tp2()
+        print_colored(f"LLM task {i} completed\n", "green")
+
+
+if __name__ == "__main__":
+    run_multi_llm_tasks()

--- /dev/null
+++ b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py

@@ -0,0 +1,43 @@
+import os
+from typing import Literal
+
+import click
+
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
+from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
+from tensorrt_llm.llmapi.utils import print_colored
+
+
+def run_task(task_type: Literal["submit", "submit_sync"]):
+    tasks = range(10)
+    assert os.environ[
+        LlmLauncherEnvs.
+        TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
+    client = RemoteMpiCommSessionClient(
+        os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR])
+    for task in tasks:
+        if task_type == "submit":
+            client.submit(print_colored, f"{task}\n", "green")
+        elif task_type == "submit_sync":
+            res = client.submit_sync(print_colored, f"{task}\n", "green")
+            print(res)
+
+
+def run_multi_tasks(task_type: Literal["submit", "submit_sync"]):
+    for i in range(3):
+        print_colored(f"Running MPI comm task {i}\n", "green")
+        run_task(task_type)
+        print_colored(f"MPI comm task {i} completed\n", "green")
+
+
+@click.command()
+@click.option("--task_type",
+              type=click.Choice(["submit", "submit_sync"]),
+              default="submit")
+def main(task_type: Literal["submit", "submit_sync"]):
+    run_multi_tasks(task_type)
+
+
+if __name__ == "__main__":
+    main()
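
Usage note: `submit` enqueues each task without waiting, while `submit_sync` blocks for the per-rank results, so the latter also exercises the server's result path. The script presumably runs under the launcher, e.g. `mpirun -n 2 trtllm-llmapi-launch python3 _run_multi_mpi_comm_tasks.py --task_type submit_sync` (invocation inferred from the test command below), which is what exports TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR for the assert above.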

--- a/tests/unittest/llmapi/test_mpi_session.py
+++ b/tests/unittest/llmapi/test_mpi_session.py

@@ -5,6 +5,8 @@ import threading
 from subprocess import PIPE, Popen
 from typing import Literal
 
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
 import pytest
 
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
@@ -12,6 +14,11 @@ from tensorrt_llm.llmapi.mpi_session import (MPINodeState, MpiPoolSession,
                                              RemoteMpiCommSessionClient,
                                              split_mpi_env)
 
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.util import skip_single_gpu
+# isort: on
+
 
 def task0():
     if MPINodeState.state is None:
@@ -108,3 +115,54 @@ def task1():
 def test_split_mpi_env():
     session = MpiPoolSession(n_workers=4)
     session.submit_sync(task1)
+
+
+@skip_single_gpu
+@pytest.mark.parametrize(
+    "task_script", ["_run_mpi_comm_task.py", "_run_multi_mpi_comm_tasks.py"])
+def test_llmapi_launch_multiple_tasks(task_script: str):
+    """
+    Test that trtllm-llmapi-launch can run multiple tasks under a single launch.
+    """
+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    test_file = os.path.join(cur_dir, "_run_multi_llm_tasks.py")
+    assert os.path.exists(test_file), f"Test file {test_file} does not exist"
+
+    command = [
+        "mpirun", "-n", "2", "--allow-run-as-root", "trtllm-llmapi-launch",
+        "python3", test_file
+    ]
+    print(' '.join(command))
+    with Popen(command,
+               env=os.environ,
+               stdout=PIPE,
+               stderr=PIPE,
+               bufsize=1,
+               start_new_session=True,
+               universal_newlines=True,
+               cwd=os.path.dirname(os.path.abspath(__file__))) as process:
+
+        # Function to read from a stream and write to output
+        def read_stream(stream, output_stream):
+            for line in stream:
+                output_stream.write(line)
+                output_stream.flush()
+
+        # Create threads to read stdout and stderr concurrently
+        stdout_thread = threading.Thread(target=read_stream,
+                                         args=(process.stdout, sys.stdout))
+        stderr_thread = threading.Thread(target=read_stream,
+                                         args=(process.stderr, sys.stderr))
+
+        # Start both threads
+        stdout_thread.start()
+        stderr_thread.start()
+
+        # Wait for the process to complete
+        return_code = process.wait()
+
+        # Wait for both threads to finish reading
+        stdout_thread.join()
+        stderr_thread.join()
+
+        if return_code != 0:
+            raise subprocess.CalledProcessError(return_code, command)
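
Design note: pumping stdout and stderr on separate threads while the child runs avoids the classic Popen deadlock, where an unread, full pipe buffer blocks the subprocess; `process.wait()` is only reached once both streams are being drained.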