mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> Co-authored-by: Tao Li @ NVIDIA <tali@nvidia.com>
136 lines
4.5 KiB
Python
136 lines
4.5 KiB
Python
import asyncio
|
|
import concurrent.futures
|
|
import os
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from queue import Empty, Queue
|
|
from typing import Any, Callable, List, NamedTuple, Optional
|
|
|
|
from tensorrt_llm._utils import mpi_rank
|
|
from tensorrt_llm.llmapi.utils import print_colored_debug
|
|
from tensorrt_llm.logger import logger
|
|
|
|
from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
|
|
RemoteMpiCommSessionClient)
|
|
from ..llmapi.utils import print_colored_debug
|
|
|
|
# Evaluated once at import time; the mode is on only when the variable is
# exactly "1" (any other value, or an unset variable, means off).
PERIODICAL_RESP_IN_AWAIT = (
    os.environ.get("TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT") == "1")
|
|
|
|
|
|
def get_spawn_proxy_process_ipc_addr_env() -> str | None:
|
|
''' Get the IPC address for the spawn proxy process dynamically. '''
|
|
return os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR")
|
|
|
|
|
|
def get_spawn_proxy_process_env() -> bool:
    """Report whether the spawn-proxy mode is enabled.

    Reads ``TLLM_SPAWN_PROXY_PROCESS`` on every call. Only the exact value
    ``"1"`` enables the mode; any other value, or an unset variable, yields
    ``False``.
    """
    flag = os.environ.get("TLLM_SPAWN_PROXY_PROCESS")
    return flag == "1"
|
|
|
|
|
|
# Module-level: emitted once, at import time, when the periodical-response
# mode was switched on through the environment.
if PERIODICAL_RESP_IN_AWAIT:
    logger.info("Using periodical responses in await_responses")
|
|
|
|
|
|
def create_mpi_comm_session(
        n_workers: int) -> RemoteMpiCommSessionClient | MpiCommSession:
    """Create the MPI session used to drive the worker processes.

    Must be called on MPI rank 0.

    When ``TLLM_SPAWN_PROXY_PROCESS`` is ``"1"``, the function returns a
    ``RemoteMpiCommSessionClient`` bound to the address in
    ``TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR``. If
    ``TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY`` is set, it is decoded from hex
    and passed as the client's ``hmac_key``. Otherwise the function returns an
    ``MpiCommSession`` with ``n_workers`` workers.

    Args:
        n_workers: Worker count for the local ``MpiCommSession``. It is not
            used by the remote-proxy branch.

    Returns:
        The created session object.

    Raises:
        AssertionError: If called on a rank other than 0, or if proxy mode is
            enabled but ``TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR`` is unset/empty.
        ValueError: If ``TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY`` is not a
            valid hex string.
    """
    rank = mpi_rank()
    assert rank == 0, f"create_mpi_comm_session must be called by rank 0, but it was called by rank {rank}"

    if get_spawn_proxy_process_env():
        # Read the address once so the check, the log line and the client
        # all see the same value.
        ipc_addr = get_spawn_proxy_process_ipc_addr_env()
        assert ipc_addr, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set."
        print_colored_debug(
            f"Using RemoteMpiCommSessionClient to bind to external MPI processes at {ipc_addr}\n",
            "yellow")

        hmac_key = None
        hmac_key_hex = os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY")
        if hmac_key_hex is not None:
            # The key travels through the environment as a hex string.
            try:
                hmac_key = bytes.fromhex(hmac_key_hex)
            except ValueError as err:
                raise ValueError(
                    "TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY must be a "
                    "hex-encoded string") from err
        return RemoteMpiCommSessionClient(addr=ipc_addr, hmac_key=hmac_key)

    print_colored_debug(
        "Using MpiCommSession to bind to external MPI processes\n", "yellow")
    return MpiCommSession(n_workers=n_workers)
|
|
|
|
|
|
def has_event_loop() -> bool:
    """Return True when called while an asyncio event loop is running.

    ``asyncio.get_running_loop`` raises ``RuntimeError`` when no loop is
    running in the current thread. That case is reported as False.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        running = False
    else:
        running = True
    return running
|
|
|
|
|
|
class RequestError(RuntimeError):
    """Raised when a request fails."""
|
|
|
|
|
|
class ProcessPoolExecutorSession(MpiSession):
    """An ``MpiSession`` backed by a local ``ProcessPoolExecutor``.

    This process pool is introduced for better recoverable exceptions
    handling. It replaces MpiPoolExecutor for the single-gpu case.
    """

    def __init__(self, n_workers: int, **kwargs):
        """Create the underlying process pool.

        Args:
            n_workers: Pool size. It is also the number of times each task is
                submitted by ``submit``/``submit_sync``.
            **kwargs: Forwarded verbatim to ``ProcessPoolExecutor``.
        """
        self.n_workers = n_workers
        self.mpi_pool = ProcessPoolExecutor(max_workers=self.n_workers,
                                            **kwargs)

    def submit(self, task: Callable, *args,
               **kwargs) -> List[concurrent.futures.Future]:
        """Submit ``task(*args, **kwargs)`` to the pool ``n_workers`` times.

        Returns:
            One future per submission, in submission order.
        """
        return [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers)
        ]

    def submit_sync(self, task: Callable, *args, **kwargs) -> List[Any]:
        """Like ``submit``, but wait for every submission and return results.

        Returns:
            The task results, in submission order.

        Raises:
            The exception raised by the first failing submission, taken in
            submission order (propagated by ``Future.result``).
        """
        return [
            future.result()
            for future in self.submit(task, *args, **kwargs)
        ]

    def shutdown(self, wait: bool = True):
        """Shut the pool down.

        Args:
            wait: Forwarded to ``ProcessPoolExecutor.shutdown``. The default
                ``True`` blocks until pending work finishes, as before.
        """
        self.mpi_pool.shutdown(wait=wait)
|
|
|
|
|
|
class ErrorResponse(NamedTuple):
    """Immutable ``(client_id, error_msg, request_id)`` error record."""

    client_id: int
    error_msg: str  # human-readable error text
    request_id: int
|
|
|
|
|
|
class IntraProcessQueue:
    """A Queue-like container for IPC within the same process.

    Thin wrapper around ``queue.Queue`` exposing ``put``/``get``/``poll``/
    ``close``.
    """

    def __init__(self):
        self.queue = Queue()

    def put(self, obj: Any):
        """Append ``obj``. The queue is unbounded, so this does not block."""
        self.queue.put(obj)

    def get(self, timeout=None) -> Any:
        """Remove and return the oldest item.

        Blocks until an item is available, or for at most ``timeout`` seconds
        when one is given. In that case ``queue.Empty`` is raised on expiry.
        """
        return self.queue.get(timeout=timeout)

    def close(self):
        """No-op: there is nothing to release for an in-process queue."""
        pass

    def poll(self, timeout=None) -> bool:
        """Wait until an item is available, without consuming it.

        Args:
            timeout: Maximum seconds to wait. ``None`` waits indefinitely and
                ``0`` checks without waiting.

        Returns:
            True if the queue is non-empty, False if the timeout expired.

        Raises:
            ValueError: If ``timeout`` is negative (same as ``Queue.get``).
        """
        if timeout is not None and timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")
        q = self.queue
        # Wait on the queue's own "not empty" condition rather than doing
        # get() followed by put(): that would move the head item to the tail,
        # reordering the FIFO, and briefly hide it from other consumers.
        # ``_qsize`` is used because ``qsize`` would re-acquire the
        # non-reentrant mutex that ``not_empty`` is already holding here.
        with q.not_empty:
            ready = q.not_empty.wait_for(lambda: q._qsize() > 0, timeout)
            if ready:
                # ``put`` wakes a single waiter. If that waiter was this poll,
                # pass the wakeup on so a consumer blocked in ``get`` runs.
                q.not_empty.notify()
            return ready
|
|
|
|
|
|
class WorkerCommIpcAddrs(NamedTuple):
    """IPC endpoints for communicating with the worker processes.

    Each field is an ``(address, hmac_key)`` pair. The first element is the
    IPC address string and the second is an optional HMAC key (bytes or
    ``None``).
    """

    request_queue_addr: tuple[str, Optional[bytes]]
    request_error_queue_addr: tuple[str, Optional[bytes]]
    result_queue_addr: tuple[str, Optional[bytes]]
    stats_queue_addr: tuple[str, Optional[bytes]]
    kv_cache_events_queue_addr: tuple[str, Optional[bytes]]
|