TensorRT-LLMs/tensorrt_llm/executor/ipc.py
Yan Chunwei 85406f9dda
[https://nvbugs/5720482][fix] Fix test rpc streaming (#9902)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
2025-12-13 01:14:43 -08:00

528 lines
19 KiB
Python

import asyncio
import hashlib
import hmac
import os
import pickle # nosec B403
import threading
import time
import traceback
from queue import Queue
from typing import Any, Optional
import zmq
import zmq.asyncio
from tensorrt_llm.logger import logger
from .._utils import nvtx_mark, nvtx_range_debug
from ..llmapi.utils import (ManagedThread, enable_llm_debug, logger_debug,
print_colored)
class ZeroMqQueue:
    ''' A Queue-like container for IPC using ZeroMQ. '''

    socket_type_str = {
        zmq.PAIR: "PAIR",
        zmq.PULL: "PULL",
        zmq.PUSH: "PUSH",
        zmq.ROUTER: "ROUTER",
        zmq.DEALER: "DEALER",
    }

    def __init__(self,
                 address: Optional[tuple[str, Optional[bytes]]] = None,
                 *,
                 socket_type: int = zmq.PAIR,
                 is_server: bool,
                 is_async: bool = False,
                 name: Optional[str] = None,
                 use_hmac_encryption: bool = True):
        '''
        Parameters:
            address (tuple[str, Optional[bytes]], optional): The address (tcp-ip_port, hmac_auth_key) for the IPC. Defaults to None. If hmac_auth_key is None and use_hmac_encryption is False, the queue will not use HMAC encryption.
            socket_type (int): The type of socket to use. Defaults to zmq.PAIR.
            is_server (bool): Whether the current process is the server or the client.
            is_async (bool): Whether to use asyncio for the socket. Defaults to False.
            name (str, optional): The name of the queue. Defaults to None.
            use_hmac_encryption (bool): Whether to use HMAC encryption for pickled data. Defaults to True.

        Raises:
            ValueError: If a client is created with encryption enabled but
                without an HMAC key, or if an HMAC key is supplied while
                encryption is disabled.
        '''
        # Pre-set the resources so close()/__del__ are safe even if the
        # constructor raises before they are created.
        self.context = None
        self.socket = None
        self.poller = None

        self.socket_type = socket_type
        self.address_endpoint = address[
            0] if address is not None else "tcp://127.0.0.1:*"
        self.is_server = is_server
        self.name = name
        self.hmac_key = address[1] if address is not None else None
        self.use_hmac_encryption = use_hmac_encryption

        # Validate the HMAC configuration *before* allocating any ZMQ
        # resources, so an invalid configuration leaves nothing behind.
        if self.use_hmac_encryption and not self.is_server and self.hmac_key is None:
            raise ValueError(
                "Client must receive HMAC key when encryption is enabled")
        elif not self.use_hmac_encryption and self.hmac_key is not None:
            raise ValueError(
                "Server and client should not receive HMAC key when encryption is disabled"
            )

        self._setup_done = False
        self._setup_lock = threading.Lock()
        # For ROUTER sockets, track the last identity to enable replies. For
        # now we assume there is only one client in our case.
        self._last_identity = None
        # Thread safety debugging, enabled when TLLM_LLMAPI_ZMQ_DEBUG != '0'.
        self._zmq_thread_id = None
        self._zmq_debug_enabled = os.environ.get('TLLM_LLMAPI_ZMQ_DEBUG',
                                                 '0') != '0'

        self.context = zmq.Context() if not is_async else zmq.asyncio.Context()
        self.socket = self.context.socket(socket_type)
        self.socket.set_hwm(0)

        # PAIR servers, PULL and ROUTER sockets bind here; every other socket
        # connects lazily in setup_lazily().
        if (socket_type == zmq.PAIR
                and self.is_server) or socket_type in (zmq.PULL, zmq.ROUTER):
            # Binds to the address and occupies a port immediately.
            self.socket.bind(self.address_endpoint)
            # Replace a wildcard endpoint (e.g. port "*") with the real one.
            self.address_endpoint = self.socket.getsockopt(
                zmq.LAST_ENDPOINT).decode()
            logger_debug(
                f"Server [{name}] bound to {self.address_endpoint} in {self.socket_type_str.get(socket_type, str(socket_type))}\n",
                "green")

            if self.use_hmac_encryption and not self.hmac_key:
                # Initialize HMAC key for pickle encryption
                logger.info(f"Generating a new HMAC key for server {self.name}")
                self.hmac_key = os.urandom(32)

        self.address = (self.address_endpoint, self.hmac_key)

    def setup_lazily(self):
        '''Connect client sockets and create the poller on first use.

        A lock with a fast-path flag check ensures the setup runs only once.
        The flag is published only *after* the setup has succeeded, so other
        threads never observe a half-initialized queue (``self.poller`` still
        ``None``), and a failed ``connect`` is retried on the next call.
        '''
        # Early return if setup is already done
        if self._setup_done:
            return

        with self._setup_lock:
            if self._setup_done:
                return

            if not self.is_server:
                logger_debug(
                    f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str.get(self.socket_type, str(self.socket_type))}\n",
                    "green")
                self.socket.connect(self.address_endpoint)

            self.poller = zmq.Poller()
            self.poller.register(self.socket, zmq.POLLIN)
            self._setup_done = True

    def _check_thread_safety(self):
        """Check if the current thread is the same as the thread that first used the socket.

        Only active when TLLM_LLMAPI_ZMQ_DEBUG is set to something other
        than '0'.

        Raises:
            RuntimeError: If the socket is used from a different thread than
                the first one that used it.
        """
        if not self._zmq_debug_enabled:
            return

        current_thread_id = threading.get_ident()

        if self._zmq_thread_id is None:
            # First call - capture the thread ID
            self._zmq_thread_id = current_thread_id
            logger_debug(
                f"ZMQ socket [{self.name}] initialized on thread {current_thread_id}",
                "cyan")
        elif self._zmq_thread_id != current_thread_id:
            # Thread mismatch - raise error
            raise RuntimeError(
                f"ZMQ thread safety violation detected in [{self.name}]: "
                f"Socket created on thread {self._zmq_thread_id}, "
                f"but accessed from thread {current_thread_id}. "
                f"ZMQ sockets are not thread-safe!")

    def poll(self, timeout: int) -> bool:
        """Wait for an incoming message.

        Parameters:
            timeout (int): Timeout in seconds

        Returns:
            bool: True if the socket became readable within the timeout.
        """
        self.setup_lazily()
        self._check_thread_safety()

        # zmq.Poller expects milliseconds.
        events = dict(self.poller.poll(timeout=timeout * 1000))
        return bool(events.get(self.socket, 0) & zmq.POLLIN)

    def put(self, obj: Any, routing_id: Optional[bytes] = None):
        """Send ``obj``, blocking until ZMQ accepts it.

        Parameters:
            obj (Any): The (picklable) object to send.
            routing_id (bytes, optional): Destination identity for ROUTER
                sockets; defaults to the identity of the last received message.
        """
        self.setup_lazily()
        self._check_thread_safety()
        with nvtx_range_debug("send", color="blue", category="IPC"):
            if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
                # Need manual serialization for encryption or ROUTER multipart
                data = self._prepare_data(obj)
                self._send_data(data, routing_id=routing_id)
            else:
                # Standard socket without encryption - use pyobj directly
                self.socket.send_pyobj(obj)

    def put_noblock(self,
                    obj: Any,
                    *,
                    retry: int = 1,
                    wait_time: float = 0.001,
                    routing_id: Optional[bytes] = None):
        '''
        Put an object into the queue without blocking, and retry if the send fails.

        NOTE: It won't raise any error if the send fails; the failure is logged.

        Parameters:
            obj (Any): The object to send.
            retry (int): The number of times to retry sending the object.
            wait_time (float): The time to wait before retrying.
            routing_id (bytes, optional): Destination identity for ROUTER
                sockets; defaults to the identity of the last received message.
        '''
        assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed"
        self.setup_lazily()
        self._check_thread_safety()
        with nvtx_range_debug("send", color="blue", category="IPC"):
            # Serialize once; every attempt re-sends the same bytes.
            data = self._prepare_data(obj)
            for attempt in range(retry + 1):
                try:
                    self._send_data(data,
                                    flags=zmq.NOBLOCK,
                                    routing_id=routing_id)
                    return
                except zmq.Again:
                    if attempt < retry:
                        time.sleep(wait_time)
            logger.error(f"Failed to send object: {obj}")

    async def put_async(self, obj: Any, routing_id: Optional[bytes] = None):
        """Asynchronously send ``obj``.

        Parameters:
            obj (Any): The (picklable) object to send.
            routing_id (bytes, optional): Destination identity for ROUTER
                sockets; defaults to the identity of the last received message.

        Raises:
            Exception: Any send/serialization error is logged and re-raised.
        """
        self.setup_lazily()
        self._check_thread_safety()
        try:
            if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
                # Need manual serialization for encryption or ROUTER multipart
                data = self._prepare_data(obj)
                await self._send_data_async(data, routing_id=routing_id)
            else:
                # Standard socket without encryption
                await self.socket.send_pyobj(obj)
        except TypeError:
            logger.error(f"Cannot pickle {obj}")
            raise
        except Exception as e:
            logger.error(f"Error sending object: {e}")
            logger.error(traceback.format_exc())
            raise

        nvtx_mark("ipc.send", color="blue", category="IPC")

    async def put_async_noblock(self,
                                obj: Any,
                                routing_id: Optional[bytes] = None):
        """Asynchronously send ``obj`` with ``zmq.NOBLOCK``.

        Uses the same framing as ``put_async``: ROUTER sockets send
        ``[identity, data]`` multipart messages, so replies reach the peer.

        Parameters:
            obj (Any): The (picklable) object to send.
            routing_id (bytes, optional): Destination identity for ROUTER
                sockets; defaults to the identity of the last received message.

        Raises:
            zmq.Again: If the message cannot be queued immediately.
            Exception: Any other send/serialization error (logged, re-raised).
        """
        self.setup_lazily()
        self._check_thread_safety()
        try:
            if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
                data = self._prepare_data(obj)
                await self._send_data_async(data,
                                            routing_id=routing_id,
                                            flags=zmq.NOBLOCK)
            else:
                await self.socket.send_pyobj(obj, flags=zmq.NOBLOCK)
        except Exception as e:
            logger.error(f"Error sending object: {e}")
            logger.error(traceback.format_exc())
            raise

    def get(self) -> Any:
        """Blocking receive of the next object."""
        self.setup_lazily()
        self._check_thread_safety()
        return self._recv_data()

    async def get_async(self) -> Any:
        """Asynchronously receive the next object."""
        self.setup_lazily()
        self._check_thread_safety()
        return await self._recv_data_async()

    async def get_async_noblock(self,
                                timeout: float = 0.5,
                                return_identity: bool = False) -> Any:
        """Get data with timeout using polling to avoid message drops.

        This method uses ZMQ's NOBLOCK flag with polling instead of asyncio.wait_for
        to prevent cancelling recv operations which can cause message drops.

        Args:
            timeout: Timeout in seconds
            return_identity: Whether to return the identity of the sender (for ROUTER sockets)

        Returns:
            The received object, or (object, identity) if return_identity is True
            (identity is None for non-ROUTER sockets).

        Raises:
            asyncio.TimeoutError: If timeout is reached without receiving data
        """
        self.setup_lazily()
        self._check_thread_safety()

        # Use polling loop instead of asyncio.wait_for to avoid cancelling recv
        # which can cause message drops
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            try:
                # Try non-blocking receive
                if self.socket_type == zmq.ROUTER:
                    identity, data = await self.socket.recv_multipart(
                        flags=zmq.NOBLOCK)
                    self._last_identity = identity
                    obj = self._parse_data(data)
                else:
                    identity = None
                    if self.use_hmac_encryption:
                        data = await self.socket.recv(flags=zmq.NOBLOCK)
                        obj = self._parse_data(data)
                    else:
                        obj = await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
            except zmq.Again:
                # No message available yet
                if loop.time() >= deadline:
                    raise asyncio.TimeoutError()
                # Short sleep to avoid busy-waiting
                await asyncio.sleep(0.01)
                continue

            return (obj, identity) if return_identity else obj

    def close(self):
        """Close the socket and terminate the context; safe to call repeatedly
        and on a partially constructed instance."""
        socket = getattr(self, "socket", None)
        if socket is not None:
            socket.close()
            self.socket = None
        context = getattr(self, "context", None)
        if context is not None:
            context.term()
            self.context = None

    def _verify_hmac(self, data: bytes, actual_hmac: bytes) -> bool:
        """Verify the HMAC of received pickle data (constant-time compare)."""
        expected_hmac = hmac.new(self.hmac_key, data, hashlib.sha256).digest()
        return hmac.compare_digest(expected_hmac, actual_hmac)

    def _sign_data(self, data_before_encoding: bytes) -> bytes:
        """Append an HMAC-SHA256 signature (32 bytes) to ``data_before_encoding``."""
        hmac_signature = hmac.new(self.hmac_key, data_before_encoding,
                                  hashlib.sha256).digest()
        return data_before_encoding + hmac_signature

    def __del__(self):
        self.close()

    def _prepare_data(self, obj: Any) -> bytes:
        """Serialize object and optionally add HMAC signature."""
        data = pickle.dumps(obj)  # nosec B301
        if self.use_hmac_encryption:
            return self._sign_data(data)
        return data

    def _parse_data(self, data: bytes) -> Any:
        """Parse data and optionally verify HMAC signature.

        Raises:
            RuntimeError: If HMAC verification fails.
        """
        if self.use_hmac_encryption:
            # The trailing 32 bytes are the SHA-256 HMAC added by _sign_data.
            message_data = data[:-32]
            actual_hmac = data[-32:]
            if not self._verify_hmac(message_data, actual_hmac):
                raise RuntimeError("HMAC verification failed")
            return pickle.loads(message_data)  # nosec B301
        # NOTE: with encryption disabled the payload is unpickled without any
        # authentication; only use this mode between trusted endpoints.
        return pickle.loads(data)  # nosec B301

    def _send_data(self,
                   data: bytes,
                   flags: int = 0,
                   routing_id: Optional[bytes] = None):
        """Send data using appropriate API based on socket type.

        Raises:
            ValueError: If a ROUTER socket has no identity to route to.
        """
        if self.socket_type == zmq.ROUTER:
            identity = routing_id if routing_id is not None else self._last_identity
            if identity is None:
                raise ValueError("ROUTER socket requires identity")
            self.socket.send_multipart([identity, data], flags=flags)
        else:
            self.socket.send(data, flags=flags)

    async def _send_data_async(self,
                               data: bytes,
                               routing_id: Optional[bytes] = None,
                               flags: int = 0):
        """Async version of _send_data (``flags`` appended last to keep the
        original positional signature).

        Raises:
            ValueError: If a ROUTER socket has no identity to route to.
        """
        if self.socket_type == zmq.ROUTER:
            identity = routing_id if routing_id is not None else self._last_identity
            if identity is None:
                raise ValueError("ROUTER socket requires identity")
            await self.socket.send_multipart([identity, data], flags=flags)
        else:
            await self.socket.send(data, flags=flags)

    def _recv_data(self, return_identity: bool = False) -> Any:
        """Receive data using appropriate API based on socket type.

        ROUTER sockets remember the sender identity in ``self._last_identity``
        so a later send without ``routing_id`` replies to it. Returns
        ``(obj, identity)`` when ``return_identity`` is True (identity is None
        for non-ROUTER sockets), otherwise just ``obj``.
        """
        if self.socket_type == zmq.ROUTER:
            identity, data = self.socket.recv_multipart()
            self._last_identity = identity  # Store for replies
            obj = self._parse_data(data)
        else:
            identity = None
            if self.use_hmac_encryption:
                obj = self._parse_data(self.socket.recv())
            else:
                obj = self.socket.recv_pyobj()
        if return_identity:
            return obj, identity
        return obj

    async def _recv_data_async(self, return_identity: bool = False) -> Any:
        """Async version of _recv_data."""
        if self.socket_type == zmq.ROUTER:
            identity, data = await self.socket.recv_multipart()
            self._last_identity = identity  # Store for replies
            obj = self._parse_data(data)
        else:
            identity = None
            if self.use_hmac_encryption:
                obj = self._parse_data(await self.socket.recv())
            else:
                obj = await self.socket.recv_pyobj()
        if return_identity:
            return obj, identity
        return obj

    def notify_with_retry(self, message, max_retries=5, timeout=1):
        """
        Notify with automatic retry on failure (for DEALER socket pattern).

        The message is re-sent on every attempt until a reply arrives.

        Args:
            message: Message to send
            max_retries: Maximum retry attempts (default: 5)
            timeout: Timeout in seconds for each attempt (default: 1)

        Returns:
            bool: True if acknowledgment received, False if failed after all retries

        Raises:
            ValueError: If the socket is not a DEALER socket.
        """
        if self.socket_type != zmq.DEALER:
            raise ValueError(
                "notify_with_retry is only supported for DEALER socket for now")

        self._check_thread_safety()

        for _ in range(max_retries):
            try:
                self.put(message)
                # Wait for ACK with timeout
                if self.poll(timeout):
                    self.get()
                    return True
            except Exception as e:
                logger.error(f"Failed to notify with retry: {e}")

        return False
# Alias: ``IpcQueue`` refers to the ZeroMQ-backed queue class defined above.
IpcQueue = ZeroMqQueue
class FusedIpcQueue:
    ''' A Queue-like container for IPC with optional message batched. '''

    def __init__(self,
                 address: Optional[tuple[str, Optional[bytes]]] = None,
                 *,
                 is_server: bool,
                 fuse_message=False,
                 fuse_size=100000,
                 error_queue=None,
                 queue_cls=ZeroMqQueue,
                 **kwargs):
        '''
        Parameters:
            address (tuple[str, Optional[bytes]], optional): Forwarded to ``queue_cls``.
            is_server (bool): Forwarded to ``queue_cls``.
            fuse_message (bool): If True, ``put`` only enqueues locally and a
                background thread sends the pending objects as lists of at
                most ``fuse_size`` items.
            fuse_size (int): Maximum number of objects per fused message.
            error_queue: Passed to the background ``ManagedThread``.
            queue_cls: Underlying queue class. Defaults to ZeroMqQueue.
            **kwargs: Extra keyword arguments forwarded to ``queue_cls``.
        '''
        # Set up bookkeeping before building the underlying queue, so that
        # close()/__del__ still work if queue construction raises.
        self.fuse_message = fuse_message
        self.error_queue = error_queue
        self.fuse_size = fuse_size
        self._message_counter = 0
        self._obj_counter = 0
        self._send_thread = None
        # Signals the sender loop to exit; checked on every iteration.
        self._stop_sending = threading.Event()
        self.sending_queue = Queue() if fuse_message else None
        self.queue = queue_cls(address=address, is_server=is_server, **kwargs)

    def setup_sender(self):
        """Start the background sender thread (fused mode only, at most once)."""
        if not self.fuse_message or self._send_thread is not None:
            return

        def send_task():
            # Drain the local queue in batches until close() sets the stop
            # event. The loop must be able to exit: a thread cannot be
            # interrupted from outside, so without this check join() would
            # block forever.
            while not self._stop_sending.is_set():
                qsize = self.sending_queue.qsize()
                if qsize == 0:
                    time.sleep(0.001)
                    continue
                qsize = min(self.fuse_size, qsize)
                self._obj_counter += qsize
                message = [
                    self.sending_queue.get_nowait() for _ in range(qsize)
                ]
                self.queue.put(message)
                self._message_counter += 1
            # Falsy return signals the task is finished (assumes ManagedThread
            # stops re-invoking a task that returns a falsy value — TODO confirm).
            return False

        self._send_thread = ManagedThread(send_task,
                                          name="fused_send_thread",
                                          error_queue=self.error_queue)
        self._send_thread.start()

    def put(self, obj: Any):
        """Send ``obj``.

        In fused mode the object is only enqueued locally and sent later by
        the background thread. Otherwise it is sent immediately, wrapped in a
        list unless it already is one.
        """
        self.setup_sender()
        if self.fuse_message:
            self.sending_queue.put_nowait(obj)
        else:
            batch = obj if isinstance(obj, list) else [obj]
            self.queue.put(batch)

    def get(self) -> Any:
        return self.queue.get()

    @property
    def address(self) -> tuple[str, Optional[bytes]]:
        return self.queue.address

    def __del__(self):
        self.close()

    def print_fuse_stats(self):
        if self._message_counter > 0:
            print_colored(
                f"IPCQueue: {self._message_counter} messages, {self._obj_counter} objects sent, average: {self._obj_counter/self._message_counter}.\n",
                "green")

    def close(self):
        """Stop the sender thread (if any), then close the underlying queue.

        The sender is stopped and joined *before* the queue is closed, so it
        never uses a socket that has already been closed. Objects still
        waiting in ``sending_queue`` at this point are not sent.
        """
        send_thread = getattr(self, "_send_thread", None)
        if send_thread is not None:
            self._stop_sending.set()
            send_thread.stop()
            send_thread.join()
            self._send_thread = None

        queue = getattr(self, "queue", None)
        if queue is not None:
            queue.close()

        if enable_llm_debug():
            self.print_fuse_stats()