Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Commit 995b93bc38 (parent 82430f84dc)
[https://nvbugs/5437384][test] fix trtllm-llmapi-launch multi tests with single launch (#8397)

Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
@@ -48,6 +48,10 @@ class MPINodeState:
     '''
 
     state = None
+    # Global MPICommExecutor instance to be reused across multiple MpiCommSession instances
+    # This is necessary because MPICommExecutor can only be created once per MPI process
+    _global_comm_executor = None
+    _global_mpi_pool = None
 
     @staticmethod
     def is_initialized() -> bool:
@@ -183,6 +187,7 @@ class MpiCommSession(MpiSession):
         self.n_workers = n_workers
         self.thread_pool: Optional[ThreadPoolExecutor] = None
         self.mpi_pool: Optional[MPIPoolExecutor] = None
+        self.owns_mpi_pool = False  # Track if this instance owns the mpi_pool
 
         if n_workers <= 0:
             raise ValueError(
@@ -230,9 +235,11 @@ class MpiCommSession(MpiSession):
         return [future.result() for future in futures]
 
     def shutdown(self, wait=True):
-        if self.mpi_pool is not None:
+        # Only shutdown the mpi_pool if this instance created it
+        # For shared global mpi_pool, we don't shut it down
+        if self.mpi_pool is not None and self.owns_mpi_pool:
             self.mpi_pool.shutdown(wait=wait)
-            self.mpi_pool = None
+        self.mpi_pool = None
         if self.thread_pool is not None:
             self.thread_pool.shutdown(wait=wait)
             self.thread_pool = None
@@ -244,8 +251,37 @@ class MpiCommSession(MpiSession):
         assert not self.mpi_pool, 'MPI session already started'
 
         self.thread_pool = ThreadPoolExecutor(max_workers=2)
-        comm_executor = MPICommExecutor(self.comm)
-        self.mpi_pool = comm_executor.__enter__()
+
+        # Use global MPICommExecutor if using COMM_WORLD
+        # This is necessary because MPICommExecutor can only be created once per MPI process
+        print_colored_debug(
+            f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
+            "grey")
+        if ENABLE_MULTI_DEVICE:
+            print_colored_debug(
+                f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
+                "grey")
+        if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
+            if MPINodeState._global_comm_executor is None:
+                print_colored_debug(
+                    "Creating global MPICommExecutor for COMM_WORLD\n",
+                    "yellow")
+                MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
+                MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__(
+                )
+            else:
+                print_colored_debug(
+                    "Reusing global MPICommExecutor for COMM_WORLD\n", "yellow")
+            self.mpi_pool = MPINodeState._global_mpi_pool
+            self.owns_mpi_pool = False
+        else:
+            print_colored_debug(
+                f"_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
+                "grey")
+            # For non-COMM_WORLD communicators, create a new executor
+            comm_executor = MPICommExecutor(self.comm)
+            self.mpi_pool = comm_executor.__enter__()
+            self.owns_mpi_pool = True
 
     def __del__(self):
         self.shutdown_abort()
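For readers unfamiliar with mpi4py's executor semantics, here is a minimal standalone sketch of the behavior the branch above works around: MPICommExecutor hands a usable pool only to the root rank, while the other ranks stay inside the context serving submitted tasks, and (as the commit's own comment notes) it is not meant to be created more than once per MPI process, hence the cache on MPINodeState. This is illustrative code, not part of the patch.

    # Illustrative only: run with e.g. `mpirun -n 2 python demo.py`
    from mpi4py import MPI
    from mpi4py.futures import MPICommExecutor

    with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
        # On the root rank `executor` is a real pool; on worker ranks it is
        # None and the rank stays busy serving tasks until the context exits.
        if executor is not None:
            print(executor.submit(pow, 2, 10).result())  # prints 1024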
@@ -264,9 +300,35 @@ class RemoteTask(NamedTuple):
 class RemoteMpiCommSessionClient(MpiSession):
     '''
     RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.
+
+    Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
+    one connection at a time. Multiple LLM instances will reuse the same client connection.
     '''
+    _global_instance = None
+    _global_instance_lock = threading.Lock()
 
+    def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
+        # Implement singleton pattern to reuse the same client connection
+        # for multiple LLM instances, since PAIR sockets only support one connection
+        with cls._global_instance_lock:
+            if cls._global_instance is None or cls._global_instance.addr != addr:
+                print_colored_debug(
+                    f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+                instance = super().__new__(cls)
+                cls._global_instance = instance
+                instance._initialized = False
+            else:
+                print_colored_debug(
+                    f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+            return cls._global_instance
+
     def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
+        # Only initialize once
+        if self._initialized:
+            return
+
         # FIXME: this is a hack to avoid circular import, resolve later
         from tensorrt_llm.executor.ipc import ZeroMqQueue
         self.addr = addr
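The docstring above leans on a ZeroMQ detail: PAIR sockets are strictly exclusive, one bound endpoint talking to exactly one connected peer, so a second client cannot attach to the proxy's socket and the connection has to be shared. A minimal pyzmq sketch of that constraint (the ipc path is made up for illustration):

    import zmq

    ctx = zmq.Context.instance()
    server = ctx.socket(zmq.PAIR)
    server.bind("ipc:///tmp/pair_demo")      # hypothetical endpoint
    client = ctx.socket(zmq.PAIR)
    client.connect("ipc:///tmp/pair_demo")   # only one peer may connect

    client.send_pyobj({"task": "ping"})
    print(server.recv_pyobj())               # {'task': 'ping'}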
@@ -277,6 +339,7 @@ class RemoteMpiCommSessionClient(MpiSession):
             socket_type=zmq.PAIR,
             use_hmac_encryption=bool(hmac_key))
         self._is_shutdown = False
+        self._initialized = True
 
     def submit(self,
                task: Callable[..., T],
@@ -330,10 +393,16 @@ class RemoteMpiCommSessionClient(MpiSession):
         self.shutdown()
 
     def shutdown(self, wait=True):
-        pass
+        # NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
+        # The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
+        # LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
+        # using it. The connection stays open for the entire lifetime of the mgmn setup.
+        print_colored_debug(
+            f"RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
+            "grey")
 
     def shutdown_abort(self, grace: float = 60, reason=None):
-        pass
+        self.shutdown()
 
 
 class RemoteMpiCommSessionServer():
@@ -394,7 +463,26 @@ class RemoteMpiCommSessionServer():
     def serve(self):
         print_colored_debug(
             f"RemoteMpiCommSessionServer listening on {self.addr}\n", "yellow")
+        pending_futures = []
         while True:
+            # Wait for any pending futures from previous tasks to complete
+            # This ensures all ranks are ready before accepting the next task
+            if pending_futures:
+                print_colored_debug(
+                    f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
+                    "grey")
+                for future in pending_futures:
+                    try:
+                        future.result()  # Wait for completion
+                    except Exception as e:
+                        print_colored(
+                            f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
+                            "red")
+                pending_futures.clear()
+                print_colored_debug(
+                    "RemoteMpiCommSessionServer all pending futures completed\n",
+                    "grey")
+
             message: Optional[RemoteTask] = self.queue.get()
             if message is None:
                 print_colored_debug(
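The explicit result loop above could also be phrased with concurrent.futures.wait; a hedged sketch of the same "drain everything before accepting the next task" idea follows (drain and pending are illustrative names, not part of the patch):

    from concurrent.futures import wait

    def drain(pending: list) -> None:
        # Block until every per-rank future from the previous task is done,
        # logging failures instead of raising, then reset the list.
        done, _ = wait(pending)
        for f in done:
            if f.exception() is not None:
                print(f"rank task failed: {f.exception()}")
        pending.clear()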
@@ -411,6 +499,8 @@ class RemoteMpiCommSessionServer():
                                           *message.args, **message.kwargs)
             self.num_results = self.session.n_workers
             assert len(futures) == self.num_results == mpi_world_size()
+            # Store futures to wait for them before the next task
+            pending_futures = list(futures)
             if message.sync:
                 for future in futures:
                     future.add_done_callback(self.mpi_future_callback)
@@ -35,6 +35,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
+  # llmapi
+  - unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
 - condition:
     ranges:
       system_gpu_count:
tests/unittest/llmapi/_run_multi_llm_tasks.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import os
import sys

cur_dir = os.path.dirname(os.path.abspath(__file__))

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.utils import print_colored

# isort: off
sys.path.append(os.path.join(cur_dir, '..'))
from utils.llm_data import llm_models_root
# isort: on

model_path = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"


def run_llm_tp2():
    with LLM(model=model_path, tensor_parallel_size=2) as llm:
        sampling_params = SamplingParams(max_tokens=10, end_id=-1)
        for output in llm.generate(["Hello, my name is"], sampling_params):
            print(output)


def run_multi_llm_tasks():
    for i in range(3):
        print_colored(f"Running LLM task {i}\n", "green")
        run_llm_tp2()
        print_colored(f"LLM task {i} completed\n", "green")


if __name__ == "__main__":
    run_multi_llm_tasks()
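How this script is meant to be launched is shown by the test added further down; a programmatic equivalent would look roughly like the following (the flags mirror the test's command list, and the trtllm-llmapi-launch wrapper name comes from the commit title — this is a sketch, not code from the patch):

    import subprocess

    # Two MPI ranks, wrapped by trtllm-llmapi-launch so all three sequential
    # LLM instances in _run_multi_llm_tasks.py reuse one MPI pool.
    subprocess.run(
        ["mpirun", "-n", "2", "--allow-run-as-root", "trtllm-llmapi-launch",
         "python3", "_run_multi_llm_tasks.py"],
        check=True)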
tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import os
from typing import Literal

import click

from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
from tensorrt_llm.llmapi.utils import print_colored


def run_task(task_type: Literal["submit", "submit_sync"]):
    tasks = range(10)
    assert os.environ[
        LlmLauncherEnvs.
        TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
    client = RemoteMpiCommSessionClient(
        os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR])

    for task in tasks:
        if task_type == "submit":
            client.submit(print_colored, f"{task}\n", "green")
        elif task_type == "submit_sync":
            res = client.submit_sync(print_colored, f"{task}\n", "green")
            print(res)


def run_multi_tasks(task_type: Literal["submit", "submit_sync"]):
    for i in range(3):
        print_colored(f"Running MPI comm task {i}\n", "green")
        run_task(task_type)
        print_colored(f"MPI comm task {i} completed\n", "green")


@click.command()
@click.option("--task_type",
              type=click.Choice(["submit", "submit_sync"]),
              default="submit")
def main(task_type: Literal["submit", "submit_sync"]):
    run_multi_tasks(task_type)


if __name__ == "__main__":
    main()
@@ -5,6 +5,8 @@ import threading
 from subprocess import PIPE, Popen
 from typing import Literal
 
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
 import pytest
 
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
@@ -12,6 +14,11 @@ from tensorrt_llm.llmapi.mpi_session import (MPINodeState, MpiPoolSession,
                                              RemoteMpiCommSessionClient,
                                              split_mpi_env)
 
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.util import skip_single_gpu
+# isort: on
+
 
 def task0():
     if MPINodeState.state is None:
@@ -108,3 +115,54 @@ def task1():
 def test_split_mpi_env():
     session = MpiPoolSession(n_workers=4)
     session.submit_sync(task1)
+
+
+@skip_single_gpu
+@pytest.mark.parametrize(
+    "task_script", ["_run_mpi_comm_task.py", "_run_multi_mpi_comm_tasks.py"])
+def test_llmapi_launch_multiple_tasks(task_script: str):
+    """
+    Test that the trtllm-llmapi-launch can run multiple tasks.
+    """
+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    test_file = os.path.join(cur_dir, "_run_multi_llm_tasks.py")
+    assert os.path.exists(test_file), f"Test file {test_file} does not exist"
+    command = [
+        "mpirun", "-n", "2", "--allow-run-as-root", "trtllm-llmapi-launch",
+        "python3", test_file
+    ]
+    print(' '.join(command))
+
+    with Popen(command,
+               env=os.environ,
+               stdout=PIPE,
+               stderr=PIPE,
+               bufsize=1,
+               start_new_session=True,
+               universal_newlines=True,
+               cwd=os.path.dirname(os.path.abspath(__file__))) as process:
+        # Function to read from a stream and write to output
+        def read_stream(stream, output_stream):
+            for line in stream:
+                output_stream.write(line)
+                output_stream.flush()
+
+        # Create threads to read stdout and stderr concurrently
+        stdout_thread = threading.Thread(target=read_stream,
+                                         args=(process.stdout, sys.stdout))
+        stderr_thread = threading.Thread(target=read_stream,
+                                         args=(process.stderr, sys.stderr))
+
+        # Start both threads
+        stdout_thread.start()
+        stderr_thread.start()
+
+        # Wait for the process to complete
+        return_code = process.wait()
+
+        # Wait for both threads to finish reading
+        stdout_thread.join()
+        stderr_thread.join()
+
+        if return_code != 0:
+            raise subprocess.CalledProcessError(return_code, command)
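A note on why the test uses two reader threads: with both stdout and stderr set to PIPE, a child that fills whichever pipe is not currently being read can block, so both streams must be drained concurrently while waiting for the process. A stripped-down sketch of the same pattern with a trivial child process (illustrative, not part of the patch):

    import subprocess
    import sys
    import threading

    def pump(src, dst):
        # Forward the child's output line by line so neither OS pipe
        # buffer fills up and stalls the child.
        for line in src:
            dst.write(line)
            dst.flush()

    proc = subprocess.Popen(["python3", "-c", "print('hello')"],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                            text=True)
    readers = [threading.Thread(target=pump, args=(proc.stdout, sys.stdout)),
               threading.Thread(target=pump, args=(proc.stderr, sys.stderr))]
    for t in readers:
        t.start()
    rc = proc.wait()
    for t in readers:
        t.join()
    assert rc == 0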