# TensorRT-LLMs/tensorrt_llm/executor/rpc/rpc_server.py
# Yan Chunwei fb51de6c2e
# [TRTLLM-8189][chore] enhance GenerationExecutor with RPC (part1) (#5543)
# Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
# Signed-off-by: chunweiy <chunweiy@nvidia.com>
# Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
# Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com>
# 2025-10-05 17:28:20 +08:00
#
# 518 lines
# 22 KiB
# Python

import asyncio
import inspect
import queue
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from ...llmapi.utils import ManagedThread, logger_debug
from ...logger import logger
from ..ipc import ZeroMqQueue
from .rpc_common import (RPCError, RPCRequest, RPCResponse, RPCStreamingError,
RPCTimeout)
class RPCServer:
    """
    An RPC Server that listens for requests and executes them concurrently.

    Requests arrive on one ZMQ socket (created by ``bind``). A single
    dispatcher coroutine moves them into an asyncio queue, and
    ``num_workers`` worker coroutines execute them. All of these coroutines
    run on one event loop inside the background thread created by ``start``.
    """

    # Requests to these methods are never counted in ``_num_pending_requests``:
    # ``shutdown`` waits for that counter to reach zero, so counting the
    # shutdown call itself would make it wait on itself.
    _UNCOUNTED_METHODS = frozenset(("_rpc_shutdown", "shutdown"))

    def __init__(self,
                 instance,
                 hmac_key=None,
                 num_workers: int = 4,
                 timeout: float = 0.5,
                 async_run_task: bool = False):
        """
        Initializes the server with an instance.

        Args:
            instance: The instance whose public methods will be exposed via RPC.
            hmac_key (bytes, optional): HMAC key forwarded to the ZMQ socket
                created in ``bind``.
            num_workers (int): Number of worker coroutines consuming the request
                queue. When ``async_run_task`` is True it is also the size of
                the thread pool used for synchronous functions.
            timeout (float): Poll interval, in seconds, used by each worker while
                waiting for a queued request. It bounds how quickly idle workers
                notice shutdown. Per-call timeouts come from each request.
            async_run_task (bool): If True, synchronous functions run in a
                dedicated ``ThreadPoolExecutor``. Otherwise they run in the
                event loop's default executor.

        NOTE: make num_workers larger if there are some streaming tasks that run
        infinitely, since each such task occupies a worker for its lifetime.
        """
        self._instance = instance
        self._hmac_key = hmac_key
        self._num_workers = num_workers
        self._address = None
        self._timeout = timeout
        self._client_socket = None

        # Once set, the dispatcher stops accepting requests and the workers exit
        # after draining the pending ones.
        self._stop_event = threading.Event()
        # Requests handed to the workers but not finished yet (shutdown calls
        # are excluded, see _UNCOUNTED_METHODS).
        self._num_pending_requests = 0

        self._functions = {
            "_rpc_shutdown": lambda: self.shutdown(is_remote_call=True),
            "_rpc_get_attr": lambda name: self.get_attr(name),
        }

        self._dispatcher_thread: Optional[ManagedThread] = None
        if async_run_task:
            self._executor = ThreadPoolExecutor(
                max_workers=num_workers, thread_name_prefix="rpc_server_worker")
        else:
            self._executor = None

        # Created inside the server's event loop by ``start``.
        self._queue = None

        # Automatically register the instance
        self.register_instance(instance)

        logger_debug(f"RPC Server initialized with {num_workers} workers.",
                     color="green")

    @property
    def address(self) -> str:
        """The address reported by the bound client-facing socket."""
        assert self._client_socket is not None, "Client socket is not bound"
        return self._client_socket.address[0]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()

    def bind(self, address="tcp://*:5555"):
        """
        Bind the server to the specified address.

        Args:
            address (str): The ZMQ address to bind the client-facing socket.
        """
        self._address = address
        # NOTE(review): hmac_key is forwarded but use_hmac_encryption=False;
        # confirm against ZeroMqQueue whether the key has any effect here.
        self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
                                          is_server=True,
                                          is_async=True,
                                          use_hmac_encryption=False)
        logger.info(f"RPC Server bound to {self._address}")

    def shutdown(self, is_remote_call: bool = False):
        """Trigger server shutdown and wait for pending requests to finish.

        Args:
            is_remote_call: Whether the shutdown is called by a remote call.
                This should be True when client.server_shutdown() is called.
                A remote shutdown runs inside a worker executor thread, so it
                only waits for pending requests. It does not join the
                dispatcher thread, shut down the executor or close the socket.
        """
        # NOTE: shutdown is also a remote method, so it could be executed by
        # a thread in a worker executor thread
        if self._stop_event.is_set():
            return

        logger_debug(
            "RPC Server shutdown signal received. Terminating server...")

        # Setting the stop event makes the dispatcher stop accepting new
        # requests, while the workers keep processing the pending ones.
        self._stop_event.set()

        # The worker routine should process the pending requests
        logger_debug(
            f"RPC Server shutdown: {self._num_pending_requests} pending requests"
        )
        while self._num_pending_requests > 0:
            time.sleep(0.01)
        logger_debug(f"RPC Server shutdown finished pending requests")

        if not is_remote_call:
            # Block the thread until shutdown is finished
            # 1. Wait for the dispatcher thread to exit, so that no new requests are accepted
            logger_debug(f"RPC Server dispatcher thread joining")
            if self._dispatcher_thread:
                self._dispatcher_thread.join()
                self._dispatcher_thread = None
            logger_debug(f"RPC Server dispatcher thread joined")

            # 2. Wait for the executor to exit, it will wait for the pending requests to be processed
            if self._executor:
                self._executor.shutdown(wait=True)
                self._executor = None

            # 3. (Optionally) Close the client socket, this doesn't affect
            # anything since zmq client will not timeout even if the target is not available
            if self._client_socket:
                self._client_socket.close()
        else:
            # This method itself runs in an executor thread that the event loop
            # is awaiting, so the dispatcher thread cannot be joined here.
            # Wait once more: a request the dispatcher accepted just before it
            # observed the stop event may have been counted after the wait above.
            logger_debug(
                f"RPC Server to shutdown: {self._num_pending_requests} pending requests"
            )
            while self._num_pending_requests > 0:
                time.sleep(0.01)
            logger_debug(f"RPC Server shutdown finished pending requests")

    def register_function(self, func, name=None):
        """Exposes a single function to clients, overwriting any existing name."""
        fname = name or func.__name__
        if fname in self._functions:
            logger.warning(
                f"Function '{fname}' is already registered. Overwriting.")
        self._functions[fname] = func
        logger_debug(f"Registered function: {fname}")

    def register_instance(self, instance):
        """Exposes all public (non-underscore) callable attributes of an instance."""
        logger_debug(
            f"Registering instance of class: {instance.__class__.__name__}")
        for name in dir(instance):
            if not name.startswith('_'):
                attr = getattr(instance, name)
                if callable(attr):
                    self.register_function(attr, name)

    def get_attr(self, name: str):
        """ Get the attribute of the RPC server.

        This is mainly used for testing. """
        return getattr(self, name)

    def _is_counted(self, req: RPCRequest) -> bool:
        """Whether ``req`` is tracked in ``_num_pending_requests``."""
        return req.method_name not in self._UNCOUNTED_METHODS

    async def _dispatcher_routine(self, stop_event: threading.Event):
        """Move incoming requests from the socket into the worker queue.

        Once shutdown starts, the dispatcher exits first and the workers
        continue to process the pending requests.
        """
        assert self._client_socket is not None, "Client socket is not bound"
        assert self._queue is not None, "RPC queue is not initialized"

        while not stop_event.is_set():
            try:
                req: RPCRequest = await self._client_socket.get_async_noblock(
                    timeout=0.5)
                logger_debug(f"RPC dispatcher got request: {req}")
            except asyncio.TimeoutError:
                await asyncio.sleep(0)
                continue
            except Exception as e:
                logger.error(f"RPC dispatcher caught an exception: {e}")
                logger.error(traceback.format_exc())
                continue

            # Count the request before a worker can see it. Otherwise a worker
            # could decrement first and let shutdown observe a spurious zero.
            if self._is_counted(req):
                self._num_pending_requests += 1
            await self._queue.put(req)  # type: ignore

            logger_debug(
                f"Dispatcher received request {req}, pending: {self._num_pending_requests}"
            )

    # TODO optimization: resolve the sequential scheduling for the remote calls
    # Suppose tons of submit remote call block the FIFO queue, and the later get_stats remote calls may be blocked
    # There could be two dispatch modes:
    # 1. (current) mix mode, share the same routine/pool
    # 2. (promising) stream mode, specific remote_call -> stream -> specific routine/pool
    # - get_stats() - 1, remote_call -> dedicated queue -> dedicated routine/pool
    # - submit() - 3 -> dedicated queue -> dedicated routine/pool
    # TODO potential optimization: for submit(), batch the ad-hoc requests in an interval like 5ms, reduce the IPC count
    async def _worker_routine(self, stop_event: threading.Event):
        """The routine executed by each worker coroutine.

        Keeps consuming queued requests until the stop event is set and no
        counted request is still pending.
        """
        assert self._client_socket is not None, "Client socket is not bound"
        assert self._queue is not None, "RPC queue is not initialized"

        while (not stop_event.is_set()) or self._num_pending_requests > 0:
            try:
                req: RPCRequest = await asyncio.wait_for(
                    self._queue.get(),  # type: ignore
                    timeout=self._timeout)
            except asyncio.TimeoutError:
                await asyncio.sleep(0)
                continue

            try:
                await self._handle_request(req)
            except Exception as e:
                # Keep the worker alive: an escaping exception would tear down
                # the whole gathered server loop.
                logger.error(
                    f"RPC worker failed to handle request {req}: {e}")
                logger.error(traceback.format_exc())
            finally:
                # Decrement exactly once, and only for requests the dispatcher counted.
                if self._is_counted(req):
                    self._num_pending_requests -= 1

    async def _handle_request(self, req: RPCRequest):
        """Execute one dequeued request and send its response(s), if any."""
        func = self._functions.get(req.method_name)
        if func is None:
            await self._reply_method_not_found(req)
            return

        if req.is_streaming:
            if inspect.isasyncgenfunction(func):
                await self._process_streaming_request(req)
            else:
                # Non-streaming function called with streaming flag
                response = RPCResponse(
                    req.request_id,
                    None,
                    RPCStreamingError(
                        f"Method '{req.method_name}' is not a streaming function."
                    ),
                    # need to redirect the error to the client's streaming queue
                    is_streaming=True,
                    stream_status='error',
                )
                await self._send_response(req, response)
            return

        # Process regular request
        response = await self._process_request(req)
        # Some tasks don't need response, e.g. submit_request or shutdown
        if req.need_response and response is not None:
            logger_debug(
                f"RPC Server sending response for request {req}, pending: {self._num_pending_requests}"
            )
            if await self._send_response(req, response):
                logger_debug(f"RPC Server sent response for request {req}")

    async def _reply_method_not_found(self, req: RPCRequest):
        """Log an unknown method and, if a reply is expected, send an error."""
        message = f"Method '{req.method_name}' not found in RPC server."
        logger.error(message)
        if not req.need_response:
            return

        if req.is_streaming:
            response = RPCResponse(
                req.request_id,
                None,
                RPCStreamingError(message, traceback=traceback.format_exc()),
                # need to redirect the error to the client's streaming queue
                is_streaming=True,
                stream_status='error')
        else:
            response = RPCResponse(
                req.request_id,
                None,
                RPCError(message, traceback=traceback.format_exc()),
            )
        await self._send_response(req, response)

    def _calculate_adjusted_timeout(self,
                                    req: RPCRequest,
                                    is_streaming: bool = False
                                    ) -> Optional[float]:
        """Calculate adjusted timeout based on pending overhead.

        The time the request spent in transit and in the queue (measured from
        ``req.creation_timestamp``) is subtracted from ``req.timeout``. The
        result is never below 0.1s.

        Args:
            req: The RPC request
            is_streaming: Whether this is for a streaming request (only affects logging)

        Returns:
            The adjusted timeout in seconds, or None when the request has no
            timeout (``req.timeout`` is None or non-positive).
        """
        timeout = req.timeout
        if timeout is None or timeout <= 0:
            # Non-positive timeouts mean "no timeout". Passing them to
            # asyncio.wait_for would instead time out immediately.
            return None

        adjusted_timeout = timeout
        if req.creation_timestamp is not None:
            pending_time = time.time() - req.creation_timestamp
            adjusted_timeout = max(0.1, timeout -
                                   pending_time)  # Keep at least 0.1s timeout
            if pending_time > 0.1:  # Only log if significant pending time
                method_type = "streaming " if is_streaming else ""
                logger_debug(
                    f"RPC Server adjusted timeout for {method_type}{req.method_name}: "
                    f"original={req.timeout}s, pending={pending_time:.3f}s, adjusted={adjusted_timeout:.3f}s"
                )
        return adjusted_timeout

    async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]:
        """Execute a regular (non-streaming) request and build its response.

        Coroutine functions are awaited directly on the event loop. Plain
        functions run in ``self._executor`` (the loop's default executor when
        it is None). Timeouts and exceptions are converted into error
        responses rather than raised, so a response is always returned.
        """
        func = self._functions[req.method_name]

        # Calculate adjusted timeout based on pending overhead; None = no timeout.
        adjusted_timeout = self._calculate_adjusted_timeout(req)

        try:
            if inspect.iscoroutinefunction(func):
                # Execute async function directly in event loop, no need to run in executor due to the GIL
                logger_debug(
                    f"RPC Server running async task {req.method_name} in dispatcher"
                )
                result = await asyncio.wait_for(func(*req.args, **req.kwargs),
                                                timeout=adjusted_timeout)
            else:
                # Execute sync function in thread executor
                loop = asyncio.get_running_loop()

                def call_with_kwargs():
                    # run_in_executor only forwards positional arguments.
                    return func(*req.args, **req.kwargs)

                logger_debug(
                    f"RPC Server running async task {req.method_name} in worker"
                )
                # TODO: let num worker control the pool size
                result = await asyncio.wait_for(loop.run_in_executor(
                    self._executor, call_with_kwargs),
                                                timeout=adjusted_timeout)

            logger_debug(f"RPC Server returned result for request {req}")
            response = RPCResponse(req.request_id, result)

        except asyncio.TimeoutError:
            response = RPCResponse(
                req.request_id, None,
                RPCTimeout(
                    f"Method '{req.method_name}' timed out after {req.timeout} seconds",
                    traceback=traceback.format_exc()))

        except Exception as e:
            response = RPCResponse(
                req.request_id, None,
                RPCError(str(e), cause=e, traceback=traceback.format_exc()))

        return response

    async def _process_streaming_request(self, req: RPCRequest):
        """Process a streaming request by sending multiple responses.

        The stream is ``'start'``, then one ``'data'`` response per yielded
        item, then ``'end'``. On failure a single ``'error'`` response
        replaces the remainder, and no ``'end'`` follows it.
        """
        func = self._functions[req.method_name]
        if not inspect.isasyncgenfunction(func):
            await self._send_response(
                req,
                RPCResponse(
                    req.request_id,
                    None,
                    RPCStreamingError(
                        f"Method '{req.method_name}' is not an async generator.",
                        traceback=traceback.format_exc()),
                    # need to redirect the error to the client's streaming queue
                    is_streaming=True,
                    stream_status='error'))
            return

        sequence_number = 0

        # Calculate adjusted timeout based on pending overhead; None = no timeout.
        adjusted_timeout = self._calculate_adjusted_timeout(req,
                                                            is_streaming=True)

        async def stream_results() -> bool:
            """Send every yielded item; return False if a send failed.

            On a failed send, _send_response has already delivered an
            error response and streaming is aborted.
            """
            nonlocal sequence_number
            agen = func(*req.args, **req.kwargs)
            try:
                async for result in agen:
                    logger_debug(
                        f"RPC Server got data and ready to send result {result}"
                    )
                    response = RPCResponse(req.request_id, result, None, True,
                                           sequence_number, 'data')
                    if not await self._send_response(req, response):
                        # Stop streaming after a pickle error
                        return False
                    sequence_number += 1
                return True
            finally:
                # Finalize the generator deterministically when streaming
                # stops early (send failure or timeout cancellation).
                await agen.aclose()

        try:
            logger_debug(f"RPC Server running streaming task {req.method_name}")
            # Send start signal
            await self._client_socket.put_async(
                RPCResponse(req.request_id, None, None, True, sequence_number,
                            'start'))
            sequence_number += 1

            if adjusted_timeout is not None:
                # The timeout covers the entire streaming operation.
                completed = await asyncio.wait_for(stream_results(),
                                                   timeout=adjusted_timeout)
            else:
                completed = await stream_results()

            if not completed:
                # An error response was already sent; do not follow it with 'end'.
                return

            # Send end signal
            await self._client_socket.put_async(
                RPCResponse(req.request_id, None, None, True, sequence_number,
                            'end'))
        except asyncio.TimeoutError:
            await self._send_response(
                req,
                RPCResponse(
                    req.request_id, None,
                    RPCTimeout(
                        f"Streaming method '{req.method_name}' timed out",
                        traceback=traceback.format_exc()), True,
                    sequence_number, 'error'))
        except Exception as e:
            response = RPCResponse(
                req.request_id, None,
                RPCStreamingError(str(e), traceback=traceback.format_exc()),
                True, sequence_number, 'error')
            await self._send_response(req, response)

    async def _send_response(self, req: RPCRequest,
                             response: RPCResponse) -> bool:
        """Safely send a response, handling pickle errors.

        If sending ``response`` fails, an error response (a streaming error for
        streaming requests) is sent instead, on a best-effort basis.

        Returns:
            True if ``response`` was sent, False otherwise.
        """
        try:
            await self._client_socket.put_async(response)
            return True
        except Exception as e:
            logger.error(
                f"Failed to pickle response for request {req.request_id}: {e}")
            error_msg = f"Failed to pickle response: {e}"
            if req.is_streaming:
                # For streaming, keep the sequence number of the failed response.
                sequence_number = response.sequence_number if response else None
                error_response = RPCResponse(
                    req.request_id,
                    None,
                    RPCStreamingError(error_msg,
                                      traceback=traceback.format_exc()),
                    is_streaming=True,
                    sequence_number=sequence_number,
                    stream_status='error')
            else:
                error_response = RPCResponse(
                    req.request_id, None,
                    RPCError(error_msg, traceback=traceback.format_exc()))

            try:
                await self._client_socket.put_async(error_response)
            except Exception as e_inner:
                logger.error(
                    f"Failed to send error response for request {req.request_id}: {e_inner}"
                )
            return False

    def start(self):
        """Start serving requests on a background thread.

        ``bind()`` must have been called first. The thread runs an event loop
        with one dispatcher coroutine and ``num_workers`` worker coroutines.

        Raises:
            RuntimeError: If the server has not been bound to an address.
        """
        if self._client_socket is None:
            raise RuntimeError(
                "Server must be bound to an address before starting. Call bind() first."
            )

        self._client_socket.setup_lazily()
        logger.info(f"RPC Server started and listening on {self._address}")

        async def tasks():
            # The queue must be created inside the loop that uses it.
            self._queue = asyncio.Queue()
            await asyncio.gather(
                self._dispatcher_routine(self._stop_event), *[
                    self._worker_routine(self._stop_event)
                    for _ in range(self._num_workers)
                ])

        def loop() -> bool:
            asyncio.run(tasks())
            return True  # return value consumed by ManagedThread

        error_queue = queue.Queue()
        self._dispatcher_thread = ManagedThread(task=loop,
                                                stop_event=self._stop_event,
                                                name="rpc_dispatcher_thread",
                                                error_queue=error_queue)
        self._dispatcher_thread.start()

        logger.info("RPC Server has started.")
# Short alias: ``Server`` and ``RPCServer`` are the same class.
Server = RPCServer