TensorRT-LLMs/tensorrt_llm/executor/ipc.py
Yan Chunwei c5e803ba48
chore: code cleanup for error logging and SharedMemory in proxy.py (#3432)
* cleanup log

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* remove shared-memory

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* remove ExecutorResponse

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* add assert for postproc

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

---------

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
2025-04-10 21:57:06 +08:00

210 lines
6.2 KiB
Python

import threading
import time
import traceback
from queue import Queue
from typing import Any, Optional

import zmq
import zmq.asyncio

from tensorrt_llm.logger import logger

from ..llmapi.utils import (ManagedThread, enable_llm_debug, nvtx_mark,
                            nvtx_range, print_colored, print_colored_debug)
class ZeroMqQueue:
    ''' A Queue-like container for IPC using ZeroMQ.

    Objects are pickled with pyzmq's ``send_pyobj`` and unpickled with
    ``recv_pyobj``, so only exchange data with trusted peers. A ``PAIR``
    socket created with ``is_server=True``, or any ``PULL`` socket, binds in
    the constructor. A client socket connects on first use (see
    ``setup_lazily``).
    '''

    # Display names for socket types; used only in debug messages.
    socket_type_str = {
        zmq.PAIR: "PAIR",
        zmq.PULL: "PULL",
        zmq.PUSH: "PUSH",
    }

    def __init__(self,
                 address: Optional[str] = None,
                 *,
                 socket_type: int = zmq.PAIR,
                 is_server: bool,
                 is_async: bool = False,
                 name: Optional[str] = None):
        '''
        Parameters:
            address (str, optional): The ZeroMQ endpoint to bind or connect
                to. Defaults to "tcp://127.0.0.1:*". When this side binds,
                ``self.address`` is replaced by the concrete endpoint
                reported by ZeroMQ (``zmq.LAST_ENDPOINT``).
            socket_type (int): The ZeroMQ socket type. Defaults to zmq.PAIR.
            is_server (bool): Whether the current process is the server or the client.
            is_async (bool): Create a ``zmq.asyncio`` context so that
                ``put_async`` / ``get_async`` can be awaited.
            name (str, optional): Label shown in debug messages.
        '''
        self.socket_type = socket_type
        self.address = address or "tcp://127.0.0.1:*"
        self.is_server = is_server
        self.name = name
        self.poller = None
        self._setup_done = False

        self.context = zmq.asyncio.Context() if is_async else zmq.Context()
        self.socket = self.context.socket(socket_type)

        if (socket_type == zmq.PAIR
                and self.is_server) or socket_type == zmq.PULL:
            # Bind immediately so the port is occupied from construction on.
            self.socket.bind(self.address)
            # Resolve a wildcard port ("*") to the endpoint actually bound.
            self.address = self.socket.getsockopt(zmq.LAST_ENDPOINT).decode()
            print_colored_debug(
                f"Server [{name}] bound to {self.address} in {self._socket_type_name()}\n",
                "green")

    def _socket_type_name(self) -> str:
        '''Return the display name of this queue's socket type.'''
        # Fall back to the numeric value for types missing from
        # socket_type_str (e.g. DEALER) instead of raising KeyError.
        return self.socket_type_str.get(self.socket_type,
                                        str(self.socket_type))

    def setup_lazily(self):
        '''Finish setup on first use: connect a client socket and create the
        poller used by ``poll``. Subsequent calls are no-ops.'''
        if self._setup_done:
            return
        self._setup_done = True

        # NOTE(review): a PULL socket created with is_server=False is both
        # bound in __init__ and connected here; confirm that combination is
        # never used.
        if not self.is_server:
            print_colored_debug(
                f"Client [{self.name}] connecting to {self.address} in {self._socket_type_name()}\n",
                "green")
            self.socket.connect(self.address)

        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    def poll(self, timeout: int) -> bool:
        """
        Wait until a message is ready to be received.

        Parameters:
            timeout (int): Timeout in seconds

        Returns:
            bool: True if the socket is readable, False if the timeout expired.
        """
        self.setup_lazily()
        events = dict(self.poller.poll(timeout=timeout * 1000))
        # The value is an event bitmask; test the POLLIN bit rather than
        # comparing for equality.
        return bool(events.get(self.socket, 0) & zmq.POLLIN)

    def put(self, obj: Any):
        '''Pickle and send ``obj`` on the socket.'''
        self.setup_lazily()
        with nvtx_range("send", color="blue", category="IPC"):
            self.socket.send_pyobj(obj)

    async def put_async(self, obj: Any):
        '''Async variant of ``put``; requires ``is_async=True``.

        Send errors are logged and then re-raised unchanged.
        '''
        self.setup_lazily()
        try:
            await self.socket.send_pyobj(obj)
        except TypeError:
            logger.error(f"Cannot pickle {obj}")
            raise
        except Exception as e:
            logger.error(f"Error sending object: {e}")
            logger.error(traceback.format_exc())
            raise
        nvtx_mark("ipc.send", color="blue", category="IPC")

    def get(self) -> Any:
        '''Receive the next message and return the unpickled object.'''
        self.setup_lazily()
        return self.socket.recv_pyobj()

    async def get_async(self) -> Any:
        '''Async variant of ``get``; requires ``is_async=True``.'''
        self.setup_lazily()
        return await self.socket.recv_pyobj()

    def close(self):
        '''Close the socket, then terminate the context.

        Safe to call more than once.
        '''
        # getattr: __del__ may run on a partially constructed instance when
        # __init__ raised before these attributes were assigned.
        socket = getattr(self, "socket", None)
        if socket is not None:
            socket.close()
            self.socket = None
        context = getattr(self, "context", None)
        if context is not None:
            context.term()
            self.context = None

    def __del__(self):
        self.close()


# Backward-compatible alias.
IpcQueue = ZeroMqQueue
class FusedIpcQueue:
    ''' A Queue-like container for IPC with optional message batching.

    Every message sent on the underlying queue is a list of objects. With
    ``fuse_message=True``, objects passed to ``put`` are buffered locally.
    A background sender thread then sends them in batches of at most
    ``fuse_size``. Otherwise each ``put`` sends immediately.
    '''

    def __init__(self,
                 address: Optional[str] = None,
                 *,
                 is_server: bool,
                 fuse_message=False,
                 fuse_size=100000,
                 error_queue=None,
                 queue_cls=ZeroMqQueue,
                 **kwargs):
        '''
        Parameters:
            address (str, optional): Forwarded to ``queue_cls``.
            is_server (bool): Forwarded to ``queue_cls``.
            fuse_message (bool): Batch objects via a background sender thread.
            fuse_size (int): Maximum number of objects per batch.
            error_queue: Passed to the sender ``ManagedThread``.
            queue_cls: Factory for the underlying queue.
            **kwargs: Extra keyword arguments forwarded to ``queue_cls``.
        '''
        # Plain state is initialized before the underlying queue is created,
        # so close() (via __del__) still works if queue_cls(...) raises.
        self.fuse_message = fuse_message
        self.error_queue = error_queue
        self.fuse_size = fuse_size
        self._message_counter = 0
        self._obj_counter = 0
        self._send_thread = None
        # Set by close() to make the sender loop exit so the thread can be joined.
        self._stop_sending = threading.Event()
        self.sending_queue = Queue() if fuse_message else None
        self.queue = queue_cls(address=address, is_server=is_server, **kwargs)

    def setup_sender(self):
        '''Start the background sender thread once (fused mode only).'''
        if not self.fuse_message or self._send_thread is not None:
            return

        def send_task():
            # Runs until close() sets the stop event. An unconditional
            # `while True` would never return, and joining the thread could
            # block forever.
            while not self._stop_sending.is_set():
                pending = self.sending_queue.qsize()
                if pending == 0:
                    time.sleep(0.001)
                    continue
                # This thread is the only consumer, so at least `pending`
                # items are available to get_nowait().
                count = min(self.fuse_size, pending)
                self._obj_counter += count
                message = [
                    self.sending_queue.get_nowait() for _ in range(count)
                ]
                self.queue.put(message)
                self._message_counter += 1

        self._stop_sending.clear()
        self._send_thread = ManagedThread(send_task,
                                          name="fused_send_thread",
                                          error_queue=self.error_queue)
        self._send_thread.start()

    def put(self, obj: Any):
        '''Send ``obj`` to the peer.

        Without fusing, a list is sent as the batch itself and any other
        object is wrapped in a one-element list. With fusing, ``obj`` is
        buffered as a single item and later sent inside a batch.
        '''
        # NOTE(review): a list passed in fused mode arrives nested inside the
        # batch, while in unfused mode it *is* the batch; confirm receivers
        # never rely on sending lists in fused mode.
        self.setup_sender()
        if self.fuse_message:
            self.sending_queue.put_nowait(obj)
        else:
            batch = obj if isinstance(obj, list) else [obj]
            self.queue.put(batch)

    def get(self) -> Any:
        '''Return the next message from the underlying queue.'''
        return self.queue.get()

    @property
    def address(self) -> str:
        '''Address of the underlying queue.'''
        return self.queue.address

    def __del__(self):
        self.close()

    def print_fuse_stats(self):
        '''Print how many batched messages and objects have been sent.'''
        if self._message_counter > 0:
            print_colored(
                f"IPCQueue: {self._message_counter} messages, {self._obj_counter} objects sent, average: {self._obj_counter/self._message_counter}.\n",
                "green")

    def close(self):
        '''Stop the sender thread (if any), then close the underlying queue.

        The sender is stopped first because it uses the underlying queue from
        another thread, and ZeroMQ sockets are not thread safe. Objects still
        buffered in ``sending_queue`` at this point are not sent.
        '''
        send_thread = getattr(self, "_send_thread", None)
        if send_thread is not None:
            self._stop_sending.set()
            send_thread.stop()
            send_thread.join()
            self._send_thread = None

        queue = getattr(self, "queue", None)
        if queue is not None:
            queue.close()

        if enable_llm_debug():
            self.print_fuse_stats()