import asyncio
import concurrent.futures
import os
from concurrent.futures import ProcessPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, List, NamedTuple, Optional

from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.logger import logger

from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
                                  RemoteMpiCommSessionClient)
from ..llmapi.utils import print_colored_debug

PERIODICAL_RESP_IN_AWAIT = os.getenv(
    "TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT") == "1"


def get_spawn_proxy_process_ipc_addr_env() -> str | None:
    ''' Get the IPC address of the spawn proxy process dynamically. '''
    return os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR")


def get_spawn_proxy_process_env() -> bool:
    ''' Check dynamically whether the spawn proxy process is enabled. '''
    return os.getenv("TLLM_SPAWN_PROXY_PROCESS") == "1"


if PERIODICAL_RESP_IN_AWAIT:
    logger.info("Using periodical responses in await_responses")


def create_mpi_comm_session(
        n_workers: int) -> RemoteMpiCommSessionClient | MpiCommSession:
    assert mpi_rank() == 0, (
        f"create_mpi_comm_session must be called by rank 0, but it was "
        f"called by rank {mpi_rank()}")
    if get_spawn_proxy_process_env():
        assert get_spawn_proxy_process_ipc_addr_env(), \
            "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set."
        print_colored_debug(
            f"Using RemoteMpiCommSessionClient to bind to external MPI processes at {get_spawn_proxy_process_ipc_addr_env()}\n",
            "yellow")
        return RemoteMpiCommSessionClient(
            get_spawn_proxy_process_ipc_addr_env())
    else:
        print_colored_debug(
            "Using MpiCommSession to bind to external MPI processes\n",
            "yellow")
        return MpiCommSession(n_workers=n_workers)


def has_event_loop() -> bool:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


class RequestError(RuntimeError):
    ''' The error raised when a request fails. '''


class ProcessPoolExecutorSession(MpiSession):
    # This process pool is introduced for better handling of recoverable
    # exceptions. It replaces MpiPoolExecutor in the single-GPU case.

    def __init__(self, n_workers: int, **kwargs):
        self.n_workers = n_workers
        self.mpi_pool = ProcessPoolExecutor(max_workers=self.n_workers,
                                            **kwargs)

    def submit(self, task: Callable, *args,
               **kwargs) -> List[concurrent.futures.Future]:
        return [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers)
        ]

    def submit_sync(self, task: Callable, *args, **kwargs) -> List[Any]:
        futures = [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers)
        ]
        return [future.result() for future in futures]

    def shutdown(self):
        self.mpi_pool.shutdown(wait=True)


class ErrorResponse(NamedTuple):
    client_id: int
    error_msg: str
    request_id: int


class IntraProcessQueue:
    ''' A Queue-like container for IPC within the same process. '''

    def __init__(self):
        self.queue = Queue()

    def put(self, obj: Any):
        self.queue.put(obj)

    def get(self, timeout=None) -> Any:
        return self.queue.get(timeout=timeout)

    def close(self):
        pass

    def poll(self, timeout=None) -> bool:
        try:
            # Try to get an item from the queue, waiting at most `timeout`
            # seconds.
            item = self.queue.get(timeout=timeout)
            # If successful, put the item back so the queue's contents are
            # preserved. Note this moves the item to the back of the queue,
            # so it is only safe with a single consumer.
            self.queue.put(item)
            return True
        except Empty:
            # The queue is empty.
            return False


class WorkerCommIpcAddrs(NamedTuple):
    ''' IPC addresses (str) and HMAC keys (bytes) for communication with the
    worker processes.
    '''
    request_queue_addr: tuple[str, Optional[bytes]]
    request_error_queue_addr: tuple[str, Optional[bytes]]
    result_queue_addr: tuple[str, Optional[bytes]]
    stats_queue_addr: tuple[str, Optional[bytes]]
    kv_cache_events_queue_addr: tuple[str, Optional[bytes]]
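

# ---------------------------------------------------------------------------
# Example usage (an illustrative sketch, not part of the original module).
# It shows how ProcessPoolExecutorSession fans one task out to every worker
# and how IntraProcessQueue.poll() peeks without consuming. The `_demo_square`
# helper is hypothetical, introduced purely for this demonstration.
# ---------------------------------------------------------------------------
def _demo_square(x: int) -> int:
    # Runs inside a worker process; defined at module top level so that
    # ProcessPoolExecutor can pickle it under the spawn start method.
    return x * x


if __name__ == "__main__":
    session = ProcessPoolExecutorSession(n_workers=2)
    try:
        # submit_sync submits the same task to every worker and blocks until
        # all results arrive.
        print(session.submit_sync(_demo_square, 7))  # -> [49, 49]
    finally:
        session.shutdown()

    q = IntraProcessQueue()
    q.put("ping")
    # poll() does a get-then-put, so the item remains available afterwards.
    assert q.poll(timeout=0.1)
    assert q.get(timeout=0.1) == "ping"
    # On an empty queue, poll() times out and reports False.
    assert not q.poll(timeout=0.1)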