TensorRT-LLMs/tensorrt_llm/executor/rpc/rpc_client.py
Yan Chunwei b86256eb54
[TRTLLM-9144][fix] enhance RPC robustness (#8711)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Co-authored-by: Erin Ho <14718778+hchings@users.noreply.github.com>
2025-12-02 21:37:59 +08:00

720 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import concurrent.futures
import os
import threading
import time
import uuid
from datetime import datetime
from typing import Any, AsyncIterator, Dict, Optional
import zmq
from tensorrt_llm._utils import (customized_gc_thresholds, nvtx_mark_debug,
nvtx_range_debug)
from ...llmapi.utils import (AsyncQueue, _SyncQueue, enable_llmapi_debug,
logger_debug)
from ...logger import logger
from ..ipc import ZeroMqQueue
from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse,
RPCStreamingError, RPCTimeout)
class RemoteCall:
    """Deferred remote invocation bound to a client, a method name and arguments.

    Instances make the chained call syntax possible, e.g.
    ``client.method(args).remote()``: the attribute access and call capture
    *what* to invoke, and one of the ``remote*`` methods decides *how*.
    """

    def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs):
        self.client = client
        self.method_name = method_name
        self.args = args
        self.kwargs = kwargs

    def _prepare_and_call(self, timeout: Optional[float], need_response: bool,
                          mode: str, call_method: str) -> Any:
        """Attach the RPC parameters to the kwargs and dispatch to the client.

        Args:
            timeout: Timeout for the RPC call.
            need_response: Whether a response is expected.
            mode: The RPC mode ("sync", "async", "future").
            call_method: Name of the client method that performs the call.

        Returns:
            Whatever the selected client method returns.
        """
        # The parameters travel inside the kwargs under a reserved key; the
        # client-side call methods pop this key again before building the request.
        self.kwargs["__rpc_params"] = RPCParams(timeout=timeout,
                                                need_response=need_response,
                                                mode=mode)
        return getattr(self.client, call_method)(self.method_name, *self.args,
                                                 **self.kwargs)

    def remote(self,
               timeout: Optional[float] = None,
               need_response: bool = True) -> Any:
        """Perform the call synchronously via the client's ``_call_sync``."""
        return self._prepare_and_call(timeout,
                                      need_response,
                                      mode="sync",
                                      call_method="_call_sync")

    def remote_async(self,
                     timeout: Optional[float] = None,
                     need_response: bool = True):
        """Return the coroutine produced by the client's ``_call_async``."""
        return self._prepare_and_call(timeout,
                                      need_response,
                                      mode="async",
                                      call_method="_call_async")

    def remote_future(self,
                      timeout: Optional[float] = None,
                      need_response: bool = True) -> concurrent.futures.Future:
        """Start the call via ``_call_future`` and return its Future."""
        return self._prepare_and_call(timeout,
                                      need_response,
                                      mode="future",
                                      call_method="_call_future")

    def remote_streaming(self,
                         timeout: Optional[float] = None) -> AsyncIterator[Any]:
        """Return the async iterator produced by ``_call_streaming``."""
        # A streaming call always expects responses, hence need_response=True.
        return self._prepare_and_call(timeout,
                                      True,
                                      mode="async",
                                      call_method="_call_streaming")
class RPCClient:
    """
    An RPC Client that connects to the RPCServer.

    Remote methods are reached through attribute access (see ``__getattr__``),
    which yields a ``RemoteCall`` that can then be executed synchronously,
    asynchronously, as a future, or as a stream. The client can be used as a
    context manager; leaving the ``with`` block calls ``close()``.
    """
def __init__(self,
             address: str,
             hmac_key=None,
             timeout: Optional[float] = None,
             num_workers: int = 4):
    '''Connect to the server and start the background machinery.

    Args:
        address: The ZMQ address to connect to.
        hmac_key: The HMAC key; it is forwarded to the socket wrapper as
            part of the address tuple (HMAC encryption itself is passed as
            disabled here).
        timeout: The default timeout (seconds) for RPC calls.
        num_workers: Number of threads in the client's worker pool.
    '''
    self._address = address
    self._timeout = timeout

    # The socket type is DEALER unless PAIR mode is requested through the
    # TLLM_LLMAPI_ZMQ_PAIR environment variable (any value other than '0').
    if os.environ.get('TLLM_LLMAPI_ZMQ_PAIR', '0') != '0':
        socket_type = zmq.PAIR
        logger_debug(
            "[client] Using zmq.PAIR socket type for RPC communication")
    else:
        socket_type = zmq.DEALER

    self._client_socket = ZeroMqQueue(address=(address, hmac_key),
                                      is_server=False,
                                      is_async=True,
                                      use_hmac_encryption=False,
                                      socket_type=socket_type,
                                      name="rpc_client")

    # Bookkeeping for in-flight requests.
    self._pending_futures = {}
    # request_id -> queue receiving the streaming responses of that request
    self._streaming_queues: Dict[str, AsyncQueue] = {}
    # Guards _streaming_queues, which is touched from several threads.
    self._streaming_queues_lock = threading.RLock()

    # Lifecycle flags.
    self._server_stopped = False
    self._closed = False

    # Background reader / event-loop state, filled in further below.
    self._reader_task = None
    self._reader_asyncio_task = None  # asyncio task, kept for cancellation
    self._loop = None
    self._loop_thread = None
    self._loop_lock = threading.Lock()  # protects event loop initialization

    self._executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=num_workers, thread_name_prefix="rpc_client_worker")

    # Create the background event loop right away so every later RPC call
    # (sync or streaming) can rely on it. Creating it lazily used to race
    # between the first streaming call and immediate fire-and-forget calls
    # such as `submit()`.
    self._ensure_event_loop()

    # Connect the socket now instead of on first use, so it is ready before
    # any RPC operation.
    self._client_socket.setup_lazily()

    # Likewise start the response reader before any call can be issued, to
    # avoid racing with streaming responses.
    self._start_response_reader_eagerly()

    logger_debug(
        f"[client] RPC Client initialized. Connected to {self._address}")
def shutdown_server(self):
"""Shutdown the server."""
if self._server_stopped:
return
self._rpc_shutdown().remote()
self._server_stopped = True
def close(self):
"""Gracefully close the client, cleaning up background tasks."""
if self._closed:
return
self._closed = True
logger_debug("[client] RPC Client closing")
# Notify any active streaming consumers so they don't hang waiting for
# further data. This must be done *before* shutting down the event
# loop/executor because they may depend on the loop to complete.
self._broadcast_streaming_error(RPCCancelled("RPC client closed"))
# 1. Cancel the reader task
if self._reader_task and not self._reader_task.done():
if self._loop and self._loop.is_running(
) and self._reader_asyncio_task:
try:
async def cancel_reader_task():
if self._reader_asyncio_task and not self._reader_asyncio_task.done(
):
self._reader_asyncio_task.cancel()
try:
await self._reader_asyncio_task
except asyncio.CancelledError:
pass
cancel_future = asyncio.run_coroutine_threadsafe(
cancel_reader_task(), self._loop)
cancel_future.result(timeout=2.0)
logger_debug("[client] Reader task cancelled successfully")
except concurrent.futures.TimeoutError:
logger.warning("Reader task did not exit gracefully")
except Exception as e:
logger_debug(f"[client] Reader task cleanup: {e}")
# 2. Stop the event loop
if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
# 3. Join the event loop thread
if self._loop_thread:
self._loop_thread.join(timeout=2.0)
if self._loop_thread.is_alive():
logger.warning("Event loop thread did not exit gracefully")
# 4. Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
# 5. Close the socket
if self._client_socket:
self._client_socket.close()
logger_debug("[client] RPC Client closed")
def _handle_streaming_response(self, response: RPCResponse):
    """Route a streaming response to the queue registered for its request.

    Responses whose request has no registered queue are ignored. When the
    response closes the stream ('end' or 'error'), the queue registration
    is removed.

    Args:
        response: The streaming response to handle
    """
    assert response.stream_status in [
        'start', 'data', 'end', 'error'
    ], f"Invalid stream status: {response.stream_status}"

    with self._streaming_queues_lock:
        target_queue = self._streaming_queues.get(response.request_id)

    if not target_queue:
        return

    logger_debug(
        f"[client] [{datetime.now().isoformat()}] Found streaming queue for response: request_id={response.request_id}, "
        f"status={response.stream_status}")

    # Deliver through the sync side of the queue: this code runs on the
    # client's loop, which differs from the loop of the consumer in
    # call_async / call_streaming.
    assert isinstance(target_queue, AsyncQueue)
    if enable_llmapi_debug() or logger.level == 'debug':
        logger_debug(
            f"[client] RPC Client putting response to AsyncQueue: status={response.stream_status}, request_id={response.request_id}"
        )
    target_queue.sync_q.put(response)

    # A terminal status means no further responses belong to this queue.
    stream_finished = response.stream_status in ['end', 'error']
    if stream_finished:
        with self._streaming_queues_lock:
            self._streaming_queues.pop(response.request_id, None)
def _handle_regular_response(self, response: RPCResponse):
    """Complete the pending future that belongs to a non-streaming response.

    The future lives on the caller's event loop, so its result (or
    exception, when ``response.error`` is set) is applied via
    ``call_soon_threadsafe`` on that loop rather than directly.

    Args:
        response: The response to handle
    """
    debug_enabled = enable_llmapi_debug() or logger.level == 'debug'

    # Take the entry out of the table in a single step. Compared to a
    # lookup followed by a pop at the end, this cannot leave the entry
    # behind when scheduling on the target loop fails below.
    future_info = self._pending_futures.pop(response.request_id, None)
    if not future_info:
        if debug_enabled:
            logger_debug(
                f"[client] No future found for request_id: {response.request_id}"
            )
        return

    future, target_loop = future_info
    if future.done():
        # Already resolved elsewhere (e.g. cancelled); nothing to deliver.
        return

    def safe_set_result():
        """Safely set result on future, handling race conditions."""
        try:
            if not future.done():
                if response.error is None:
                    future.set_result(response.result)
                else:
                    future.set_exception(response.error)
        except asyncio.InvalidStateError:
            # Future was cancelled or completed between the check and set
            # This is expected in high-load scenarios, just log and continue
            if enable_llmapi_debug() or logger.level == 'debug':
                logger_debug(
                    f"[client] Future already done for request_id: {response.request_id}, skipping"
                )

    if debug_enabled:
        if response.error is None:
            logger_debug(
                f"[client] Setting result for request_id: {response.request_id}"
            )
        else:
            logger_debug(
                f"[client] Setting exception for request_id: {response.request_id}, error: {response.error}"
            )

    try:
        target_loop.call_soon_threadsafe(safe_set_result)
    except RuntimeError as e:
        # call_soon_threadsafe raises RuntimeError when the target loop has
        # been closed; the caller is gone, so the response is dropped.
        logger.warning(
            f"Dropping response for request_id {response.request_id}: {e}")
async def _handle_reader_exception(self, exception: Exception):
    """Propagate an exception to all pending futures and streaming queues.

    Args:
        exception: The exception to propagate
    """
    logger.error(f"Exception in RPC response reader: {exception}")

    # Iterate over a snapshot: callers add and remove entries from their
    # own loops, and mutating a dict while it is being iterated raises
    # RuntimeError.
    for future, target_loop in list(self._pending_futures.values()):
        if future.done():
            continue

        def safe_set_exception(f=future, exc=exception):
            """Safely set exception on future, handling race conditions."""
            try:
                if not f.done():
                    f.set_exception(exc)
            except asyncio.InvalidStateError:
                # Future was cancelled or completed, this is fine
                pass

        try:
            target_loop.call_soon_threadsafe(safe_set_exception)
        except RuntimeError as e:
            # The target loop has been closed. Skip this future so the
            # remaining futures and the streaming queues still get notified.
            logger_debug(
                f"[client] Could not propagate reader exception to a pending future: {e}"
            )

    # Propagate to streaming queues via common helper
    self._broadcast_streaming_error(exception)
async def _wait_for_response(self) -> RPCResponse:
    """Poll the socket until a response arrives.

    Returns:
        RPCResponse from the server

    Raises:
        asyncio.CancelledError: If, after a poll times out, the client is
            closed or the reader task has been cancelled.
    """
    # Receiving with a short timeout (instead of blocking indefinitely)
    # gives a regular opportunity to notice shutdown/cancellation, which
    # keeps a CancelledError from being logged as an exception.
    while True:
        try:
            return await self._client_socket.get_async_noblock(timeout=2)
        except asyncio.TimeoutError:
            if self._closed:
                raise asyncio.CancelledError("Reader task cancelled")
            reader = self._reader_asyncio_task
            if reader and reader.cancelled():
                raise asyncio.CancelledError("Reader task cancelled")
            # Nothing received and no shutdown requested: poll again.
async def _response_reader(self):
    """Task to read responses from the socket and set results on futures.

    Runs until the client is closed, the task is cancelled, or a
    connection-level error (``ConnectionError`` / ``zmq.ZMQError``) occurs;
    in the latter case the error is propagated to all waiters first. Other
    errors are logged and the loop continues.
    """
    logger_debug("[client] Response reader started")
    self._reader_asyncio_task = asyncio.current_task()

    try:
        # Add initial delay to ensure socket is fully connected
        # This helps prevent race conditions during initialization
        await asyncio.sleep(0.1)
        logger_debug("[client] Response reader ready to process messages")

        with customized_gc_thresholds(10000):
            last_alive_log = time.time()

            while not self._closed:
                # Periodic alive logging for debugging
                if time.time() - last_alive_log > 5.0:
                    logger_debug(
                        "[client] Response reader is alive and waiting for responses"
                    )
                    last_alive_log = time.time()

                with nvtx_range_debug("response_reader",
                                      color="cyan",
                                      category="RPC"):
                    try:
                        response = await self._wait_for_response()

                        # Check the debug flag once per message, before any
                        # expensive string formatting. Formatting the whole
                        # response for a log line that is not wanted would
                        # be paid on every single message.
                        debug_enabled = (enable_llmapi_debug()
                                         or logger.level == 'debug')

                        if debug_enabled:
                            logger_debug(
                                f"[client] [{datetime.now().isoformat()}] Received response: {response}"
                            )
                        nvtx_mark_debug(
                            f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
                            color="black",
                            category="RPC")

                        if debug_enabled:
                            logger_debug(
                                f"[client] [{datetime.now().isoformat()}] RPC Client received response: request_id={response.request_id}, "
                                f"is_streaming={response.is_streaming}, "
                                f"pending_futures={len(self._pending_futures)}"
                            )

                        with nvtx_range_debug("handle_response",
                                              color="purple",
                                              category="RPC"):
                            if response.is_streaming:
                                self._handle_streaming_response(response)
                            else:
                                self._handle_regular_response(response)

                    except asyncio.CancelledError:
                        # Re-raise cancellation to exit cleanly
                        raise
                    except Exception as e:
                        # Log the error but continue reading unless it's a critical error
                        logger.error(f"Error processing response: {e}",
                                     exc_info=True)

                        # For critical errors, propagate and exit
                        if isinstance(e, (ConnectionError, zmq.ZMQError)):
                            await self._handle_reader_exception(e)
                            break

                        # For other errors, try to continue; the brief
                        # pause avoids a tight loop on repeated errors.
                        await asyncio.sleep(0.1)

    except asyncio.CancelledError:
        logger_debug("[client] Response reader cancelled")
    finally:
        logger_debug("[client] Response reader exiting gracefully")
        self._reader_task = None
        self._reader_asyncio_task = None
def _start_response_reader_eagerly(self):
"""Start the response reader immediately during initialization.
This ensures the reader is ready before any RPC calls are made,
preventing race conditions with streaming responses.
"""
if self._reader_task is not None and not self._reader_task.done():
return # Already running
try:
if self._loop and self._loop.is_running():
future = asyncio.run_coroutine_threadsafe(
self._response_reader(), self._loop)
self._reader_task = future
# Wait a bit to ensure the reader is actually processing
time.sleep(0.2)
logger_debug(
"[client] Response reader started eagerly during initialization"
)
else:
raise RuntimeError(
"Event loop not running when trying to start response reader"
)
except Exception as e:
logger.error(f"Failed to start response reader eagerly: {e}")
self._reader_task = None
raise
async def _call_async(self, method_name, *args, **kwargs):
    """Async version of RPC call.

    Args:
        method_name: Method name to call
        *args: Positional arguments
        **kwargs: Keyword arguments
        __rpc_params: RPCParams object containing RPC parameters.

    Returns:
        The result of the remote method call, or None when the call was
        made with ``need_response=False``.

    Raises:
        RPCCancelled: If the server is already marked as stopped, or the
            call completes with RPCCancelled (which also marks the server
            as stopped).
        RPCTimeout: If no response arrives within the effective timeout.
    """
    if enable_llmapi_debug() or logger.level == 'debug':
        logger_debug(
            f"[client] [{datetime.now().isoformat()}] RPC client calling method: {method_name}"
        )
    nvtx_mark_debug(f"RPC.async.{method_name}",
                    color="yellow",
                    category="RPC")
    if self._server_stopped:
        raise RPCCancelled("Server is shutting down, request cancelled")

    rpc_params = kwargs.pop("__rpc_params", RPCParams())
    need_response = rpc_params.need_response
    # A per-call timeout takes precedence over the client-wide default.
    timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout

    request_id = uuid.uuid4().hex
    request = RPCRequest(request_id,
                         method_name=method_name,
                         args=args,
                         kwargs=kwargs,
                         need_response=need_response,
                         timeout=timeout)

    if not need_response:
        # Fire-and-forget: nothing to wait for, so nothing to register.
        await self._client_socket.put_async(request)
        logger_debug(f"[client] RPC Client sent request: {request}")
        return None

    loop = asyncio.get_running_loop()
    future = loop.create_future()
    # Register the future BEFORE the request goes out. The reader drops any
    # response whose request_id it cannot find, so registering after the
    # send would let a fast response slip through and leave this call
    # waiting forever (or until its timeout).
    self._pending_futures[request_id] = (future, loop)

    try:
        await self._client_socket.put_async(request)
        logger_debug(f"[client] RPC Client sent request: {request}")

        try:
            if timeout is None:
                return await future
            # The wait is bounded by exactly `timeout` seconds; no extra
            # margin is added on top of it.
            return await asyncio.wait_for(future, timeout)
        except RPCCancelled:
            self._server_stopped = True
            raise
        except asyncio.TimeoutError:
            raise RPCTimeout(
                f"Request '{method_name}' timed out after {timeout}s")
    finally:
        # Also covers a failed send, so the entry never outlives the call.
        self._pending_futures.pop(request_id, None)
def _ensure_event_loop(self):
"""Create and start the background event loop.
This is called once during initialization to create the dedicated
event loop for all socket I/O operations.
"""
if self._loop is not None:
return # Already created
self._loop = asyncio.new_event_loop()
def run_loop():
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
self._loop_thread = threading.Thread(target=run_loop,
daemon=True,
name="rpc_client_loop")
self._loop_thread.start()
# Wait briefly to ensure the loop is running before returning
time.sleep(0.2)
def _call_sync(self, method_name, *args, **kwargs):
    """Run an RPC call on the client's loop and block until it finishes.

    Returns the value produced by ``_call_async``; exceptions raised there
    are re-raised here by ``Future.result()``.
    """
    if enable_llmapi_debug() or logger.level == 'debug':
        logger_debug(f"[client] RPC Client calling method: {method_name}")
    nvtx_mark_debug(f"RPC.sync.{method_name}",
                    color="green",
                    category="RPC")

    call_coro = self._call_async(method_name, *args, **kwargs)
    pending = asyncio.run_coroutine_threadsafe(call_coro, self._loop)
    return pending.result()
def _call_future(self, name: str, *args,
                 **kwargs) -> concurrent.futures.Future:
    """
    Start a remote call and hand back a Future for its outcome.

    Args:
        name: Method name to call
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        A Future, backed by the client's worker pool, that resolves to the
        result of the remote call.
    """
    nvtx_mark_debug(f"RPC.future.{name}", color="blue", category="RPC")

    def _run_on_client_loop():
        # Executed on a pool thread: submit the coroutine to the client's
        # event loop and block this worker until it completes.
        return asyncio.run_coroutine_threadsafe(
            self._call_async(name, *args, **kwargs), self._loop).result()

    return self._executor.submit(_run_on_client_loop)
async def _call_streaming(self, name: str, *args,
                          **kwargs) -> AsyncIterator[Any]:
    """
    Call a remote async generator method and get streaming results.

    Implementation note: The outgoing request is sent on the RPCClient's
    private event-loop to obey the single-loop rule. The returned items
    are yielded in the caller's loop via AsyncQueue, which is thread-safe.

    Yields:
        The ``result`` of every 'data' response, until an 'end' response.

    Raises:
        RPCCancelled: If the server is already marked as stopped.
        RPCTimeout: If waiting for the next response exceeds the timeout.
        RPCStreamingError: On an 'error' response that carries no error
            object; otherwise the carried error itself is raised.
    """
    nvtx_mark_debug(f"RPC.streaming.{name}", color="red", category="RPC")

    if self._server_stopped:
        raise RPCCancelled("Server is shutting down, request cancelled")

    rpc_params = kwargs.pop("__rpc_params", RPCParams())
    # A per-call timeout takes precedence over the client-wide default.
    timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout

    request_id = uuid.uuid4().hex
    # Use AsyncQueue to ensure proper cross-thread communication
    queue = AsyncQueue()
    # Recreate sync_q with the current running loop for proper cross-thread communication
    queue._sync_q = _SyncQueue(queue, asyncio.get_running_loop())

    # Register queue with lock to ensure thread-safe access
    with self._streaming_queues_lock:
        self._streaming_queues[request_id] = queue

    # Everything after the registration is covered by the `finally` below,
    # so the queue is deregistered even when building or sending the
    # request fails - not only when reading responses fails.
    try:
        # Build the RPCRequest object here - it's pickle-able and small -
        # but *do not* touch the ZeroMQ socket from this (caller) event-loop.
        request = RPCRequest(request_id,
                             method_name=name,
                             args=args,
                             kwargs=kwargs,
                             need_response=True,
                             timeout=timeout,
                             is_streaming=True)

        # Send the request on the RPCClient's dedicated loop to guarantee that
        # **all** socket I/O happens from exactly one thread / loop.
        async def _send_streaming_request(req: RPCRequest):
            """Private helper executed in the client loop to put the request."""
            logger_debug(
                f"[client] [{datetime.now().isoformat()}] Sending streaming request: {req.method_name}, request_id={req.request_id}"
            )
            await self._client_socket.put_async(req)
            logger_debug(
                f"[client][{datetime.now().isoformat()}] Streaming request sent successfully: {req.method_name}, request_id={req.request_id}"
            )

        send_future = asyncio.run_coroutine_threadsafe(
            _send_streaming_request(request), self._loop)

        # Wait until the request is actually on the wire before entering the
        # user-visible streaming loop. We wrap the concurrent.futures.Future so
        # we can await it in the caller's asyncio context.
        await asyncio.wrap_future(send_future)

        logger_debug(
            f"[client] [{datetime.now().isoformat()}] Starting to read streaming responses for request_id={request_id}"
        )

        # Only timeouts raised while reading responses are translated into
        # RPCTimeout; errors from the send above propagate unchanged.
        try:
            while True:
                if timeout is None:
                    response = await queue.get()
                else:
                    response = await asyncio.wait_for(queue.get(),
                                                      timeout=timeout)

                if enable_llmapi_debug() or logger.level == 'debug':
                    logger_debug(
                        f"[client] [{datetime.now().isoformat()}] RPC Client _call_streaming received [{response.stream_status}] response",
                        color="green")

                if response.stream_status == 'start':
                    # Start of stream
                    continue
                elif response.stream_status == 'data':
                    yield response.result
                elif response.stream_status == 'end':
                    # End of stream
                    break
                elif response.stream_status == 'error':
                    # Error in stream
                    if response.error:
                        raise response.error
                    else:
                        raise RPCStreamingError("Unknown streaming error")
        except asyncio.TimeoutError:
            raise RPCTimeout(
                f"Streaming request '{name}' timed out after {timeout}s")
    finally:
        # Clean up
        with self._streaming_queues_lock:
            self._streaming_queues.pop(request_id, None)
def _broadcast_streaming_error(self, exc: Exception):
"""Send an error response to all pending streaming queues so that
any coroutines blocked in _call_streaming can exit promptly.
Args:
exc: The exception instance to propagate downstream.
"""
# Iterate over a copy because callbacks may mutate the dict
with self._streaming_queues_lock:
streaming_items = list(self._streaming_queues.items())
for request_id, queue in streaming_items:
if not isinstance(queue, AsyncQueue):
continue
try:
# Use the underlying sync_q to ensure cross-thread delivery
queue.sync_q.put(
RPCResponse(
request_id,
result=None,
error=exc,
is_streaming=True,
chunk_index=0,
stream_status='error',
))
except Exception as e:
# Best-effort; log and continue
if enable_llmapi_debug() or logger.level == 'debug':
logger_debug(
f"[client] [{datetime.now().isoformat()}] Failed to broadcast streaming error for {request_id}: {e}"
)
def get_server_attr(self, name: str):
    """Fetch the attribute called `name` from the RPC server.

    Implemented as a synchronous `_rpc_get_attr` remote call.
    This is mainly used for testing.
    """
    attr_call = self._rpc_get_attr(name)
    return attr_call.remote()
def __getattr__(self, name):
"""
Magically handles calls to non-existent methods.
Returns a callable that when invoked returns a RemoteCall instance.
This enables the new syntax:
client.method(args).remote()
await client.method(args).remote_async()
client.method(args).remote_future()
async for x in client.method(args).remote_streaming()
"""
logger_debug(
f"[client] [{datetime.now().isoformat()}] RPC Client getting attribute: {name}"
)
def method_caller(*args, **kwargs):
return RemoteCall(self, name, *args, **kwargs)
return method_caller
def __enter__(self):
    """Enter a ``with`` block; the client itself is the managed object."""
    return self
def __exit__(self, exc_type, exc_value, traceback):
    """Leave a ``with`` block by closing the client.

    Returns None, so an exception raised inside the block is not suppressed.
    """
    self.close()