TensorRT-LLMs/tensorrt_llm/executor/rpc/rpc_server.py


import asyncio
import inspect
import os
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional
import zmq
from ...llmapi.utils import logger_debug
from ...logger import logger
from ..ipc import ZeroMqQueue
from .rpc_common import (RPCCancelled, RPCError, RPCRequest, RPCResponse,
RPCStreamingError, RPCTimeout)
class RPCServer:
"""
An RPC Server that listens for requests and executes them concurrently.
"""
def __init__(self,
instance: Any,
hmac_key: Optional[bytes] = None,
num_workers: int = 4,
timeout: float = 0.5,
async_run_task: bool = False) -> None:
"""
Initializes the server with an instance.
Args:
instance: The instance whose methods will be exposed via RPC.
hmac_key (bytes, optional): HMAC key for encryption.
num_workers (int): Number of worker threads or tasks used to parallelize task execution.
timeout (float): Timeout for RPC calls.
async_run_task (bool): Whether to run tasks asynchronously.
NOTE: increase num_workers if necessary, otherwise remote() and remote_future()
calls may be blocked waiting for the thread pool.
"""
self._instance = instance
self._hmac_key = hmac_key
self._num_workers = num_workers
self._address = None
self._timeout = timeout
self._client_socket = None
# Asyncio components
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._main_task: Optional[asyncio.Task] = None
self._worker_tasks: List[asyncio.Task] = []
self._shutdown_event: Optional[asyncio.Event] = None
self._server_thread: Optional[threading.Thread] = None
self._stop_event: threading.Event = threading.Event(
) # for thread-safe shutdown
self._num_pending_requests = 0
self._functions: Dict[str, Callable[..., Any]] = {
# Some built-in methods for RPC server
"_rpc_shutdown": lambda: self.shutdown(is_remote_call=True),
"_rpc_get_attr": lambda name: self.get_attr(name),
}
if async_run_task:
self._executor = ThreadPoolExecutor(
max_workers=num_workers, thread_name_prefix="rpc_server_worker")
else:
self._executor = None
self.register_instance(instance)
logger_debug(
f"[server] RPCServer initialized with {num_workers} workers.",
color="green")
@property
def address(self) -> str:
assert self._client_socket is not None, "Client socket is not bound"
return self._client_socket.address[0]
def __enter__(self) -> 'RPCServer':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.shutdown()
def bind(self, address: str = "tcp://*:5555") -> None:
"""
Bind the server to the specified address.
Args:
address (str): The ZMQ address to bind the client-facing socket.
"""
self._address = address
# Check if PAIR mode is enabled via environment variable
use_pair_mode = os.environ.get('TLLM_LLMAPI_ZMQ_PAIR', '0') != '0'
socket_type = zmq.PAIR if use_pair_mode else zmq.ROUTER
if use_pair_mode:
logger_debug(
"[server] Using zmq.PAIR socket type for RPC communication")
self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
is_server=True,
is_async=True,
use_hmac_encryption=self._hmac_key
is not None,
socket_type=socket_type,
name="rpc_server")
logger.info(f"RPCServer is bound to {self._address}")
def shutdown(self, is_remote_call: bool = False) -> None:
"""Internal method to trigger server shutdown.
Args:
is_remote_call: Whether the shutdown is called by a remote call.
This should be True when client.server_shutdown() is called.
"""
# NOTE: shutdown is also a remote method, so it may be executed in a
# worker executor thread
if self._stop_event.is_set():
return
logger_debug(
"[server] RPCServer is shutting down. Terminating server immediately..."
)
# Set the stop event to True, this will trigger immediate shutdown
self._stop_event.set()
# Log pending requests that will be cancelled
logger_debug(
f"[server] RPCServer is shutting down: {self._num_pending_requests} pending requests will be cancelled"
)
# Signal asyncio shutdown event if available
if self._shutdown_event and self._loop:
self._loop.call_soon_threadsafe(self._shutdown_event.set)
if not is_remote_call:
# Block the thread until shutdown is finished
# 1. Cancel the main task gracefully which will trigger proper cleanup
if self._main_task and not self._main_task.done():
self._loop.call_soon_threadsafe(self._main_task.cancel)
# 2. Wait for the server thread to exit (this will wait for proper cleanup)
if self._server_thread and self._server_thread.is_alive():
logger_debug(
"[server] RPCServer is waiting for server thread to exit")
self._server_thread.join()
self._server_thread = None
logger_debug("[server] RPCServer thread joined")
# 3. Shutdown the executor immediately without waiting for tasks
if self._executor:
self._executor.shutdown(wait=False)
self._executor = None
# 4. Close the client socket
if self._client_socket:
self._client_socket.close()
else:
# If shutdown is triggered by a remote call, this method itself is
# executed in an executor thread, so we cannot join the server thread.
logger_debug(
f"[server] RPC Server shutdown initiated: {self._num_pending_requests} pending requests will be cancelled"
)
logger_debug("[server] RPCServer is shutdown successfully",
color="yellow")
def register_function(self,
func: Callable[..., Any],
name: Optional[str] = None) -> None:
"""Exposes a single function to clients.
Args:
func: The function to register.
name: The name of the function. If not provided, the name of the function will be used.
"""
fname = name or func.__name__
if fname in self._functions:
logger.warning(
f"Function '{fname}' is already registered. Overwriting.")
self._functions[fname] = func
logger_debug(f"[server] Registered function: {fname}")
def register_instance(self, instance: Any) -> None:
"""Exposes all public methods of a class instance.
Args:
instance: The instance to register.
"""
logger_debug(
f"[server] Registering instance of class: {instance.__class__.__name__}"
)
for name in dir(instance):
if not name.startswith('_'):
attr = getattr(instance, name)
if callable(attr):
self.register_function(attr, name)
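# Instance registration sketch (the class is illustrative): every public
# callable attribute is exposed, while names starting with "_" are skipped.
#
#     class Worker:
#         def submit(self, item):      # exposed
#             ...
#         def _internal(self):         # not exposed
#             ...
#
#     server.register_instance(Worker())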
def get_attr(self, name: str) -> Any:
""" Get the attribute of the RPC server.
Args:
name: The name of the attribute to get.
"""
return getattr(self, name)
async def _drain_pending_requests(self) -> None:
"""Drain any remaining requests from the socket and send cancellation responses."""
if self._client_socket is None:
return
logger_debug("[server] Draining pending requests after shutdown")
drained_count = 0
# Give a short window to drain any in-flight requests
end_time = asyncio.get_event_loop().time() + 2
while asyncio.get_event_loop().time() < end_time:
try:
req, routing_id = await asyncio.wait_for(
self._client_socket.get_async_noblock(return_identity=True),
timeout=2)
req.routing_id = routing_id
drained_count += 1
logger_debug(f"[server] Draining request after shutdown: {req}")
# Send cancellation response
await self._send_error_response(
req,
RPCCancelled("Server is shutting down, request cancelled"))
except asyncio.TimeoutError:
# No more requests to drain
break
except Exception as e:
logger.debug(f"Error draining request: {e}")
break
if drained_count > 0:
logger_debug(
f"[server] Drained {drained_count} requests after shutdown")
async def _run_server(self) -> None:
"""Main server loop that handles incoming requests directly."""
assert self._client_socket is not None, "Client socket is not bound"
logger_debug("[server] RPC Server main loop started")
# Create worker tasks
for i in range(self._num_workers):
task = asyncio.create_task(self._process_requests())
self._worker_tasks.append(task)
try:
# Wait for all worker tasks to complete
await asyncio.gather(*self._worker_tasks)
except asyncio.CancelledError:
logger_debug("[server] RPC Server main loop cancelled")
# Cancel all worker tasks
for task in self._worker_tasks:
if not task.done():
task.cancel()
# Wait for all tasks to finish cancellation
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
except Exception as e:
logger.error(f"RPC Server main loop error: {e}")
logger.error(traceback.format_exc())
finally:
logger_debug("[server] RPC Server main loop exiting")
# TODO optimization: resolve the sequential scheduling of remote calls.
# If many submit() remote calls flood the shared FIFO queue, later get_stats()
# remote calls may be blocked behind them. Two dispatch modes are possible:
# 1. (current) mixed mode: all remote calls share the same routine/pool
# 2. (promising) stream mode: each remote call maps to its own stream and routine/pool, e.g.
#    - get_stats() - 1: remote_call -> dedicated queue -> dedicated routine/pool
#    - submit() - 3: remote_call -> dedicated queue -> dedicated routine/pool
# TODO potential optimization: for submit(), batch ad-hoc requests over a short
# interval (e.g. 5 ms) to reduce the IPC count.
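# One possible shape of the "stream mode" described above (purely illustrative;
# the queue layout and method names are assumptions, not an implemented
# design): each method name gets its own asyncio.Queue and worker loop, so a
# burst of submit() requests cannot delay get_stats().
#
#     queues = {"submit": asyncio.Queue(), "get_stats": asyncio.Queue()}
#
#     async def dispatch(req):
#         await queues.get(req.method_name, queues["submit"]).put(req)
#
#     async def method_worker(method_name):
#         while True:
#             req = await queues[method_name].get()
#             ...  # handle req with the routine dedicated to method_name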
async def _send_error_response(self, req: RPCRequest,
error: Exception) -> None:
"""Send an error response for a request."""
if not req.need_response:
return
if req.is_streaming:
await self._client_socket.put_async(
RPCResponse(
req.request_id,
result=None,
error=error,
is_streaming=
True, # Important: mark as streaming so it gets routed correctly
stream_status='error'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error response for request {req.request_id}",
color="green")
else:
await self._client_socket.put_async(RPCResponse(req.request_id,
result=None,
error=error),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error response for request {req.request_id}",
color="green")
async def _handle_shutdown_request(self, req: RPCRequest) -> bool:
"""Handle a request during shutdown. Returns True if handled."""
if not self._shutdown_event.is_set():
return False
# Allow shutdown methods to proceed
if req.method_name in ["_rpc_shutdown", "shutdown"]:
return False
# Send cancellation error for all other requests
await self._send_error_response(
req, RPCCancelled("Server is shutting down, request cancelled"))
# Decrement pending count
self._num_pending_requests -= 1
return True
async def _process_requests(self) -> None:
"""Process incoming requests directly from the socket."""
assert self._client_socket is not None, "Client socket is not bound"
while not self._shutdown_event.is_set():
try:
#logger_debug(f"[server] Worker waiting for request", color="green")
# Read request directly from socket with timeout
req, routing_id = await asyncio.wait_for(
self._client_socket.get_async_noblock(return_identity=True),
timeout=2)
req.routing_id = routing_id
logger_debug(f"[server] Worker got request: {req}",
color="green")
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
logger_debug("[server] RPC worker cancelled")
break
except Exception as e:
if self._shutdown_event.is_set():
break
logger.error(f"RPC worker caught an exception: {e}")
logger.error(traceback.format_exc())
continue
# shutdown methods depend on _num_pending_requests, so
# they should not be counted
if req.method_name not in ["_rpc_shutdown", "shutdown"]:
self._num_pending_requests += 1
logger_debug(
f"[server] Worker received request {req}, pending: {self._num_pending_requests}"
)
# Check if we should cancel due to shutdown
if await self._handle_shutdown_request(req):
continue
# Check if the method exists
if req.method_name not in self._functions:
logger.error(
f"Method '{req.method_name}' not found in RPC server.")
self._num_pending_requests -= 1
error = RPCStreamingError if req.is_streaming else RPCError
await self._send_error_response(
req,
error(
f"Method '{req.method_name}' not found in RPC server.",
traceback=traceback.format_exc()))
continue
func = self._functions[req.method_name]
# Final shutdown check before processing
if await self._handle_shutdown_request(req):
continue
# Process the request
if req.is_streaming:
if inspect.isasyncgenfunction(func):
await self._process_streaming_request(req)
else:
# Non-streaming function called with streaming flag
await self._send_error_response(
req,
RPCStreamingError(
f"Method '{req.method_name}' is not a streaming function."
))
else:
# Process regular request
response = await self._process_request(req)
# Send response if needed
if req.need_response and response is not None:
logger_debug(
f"[server] RPC Server sending response for request {req}, pending: {self._num_pending_requests}"
)
if await self._send_response(req, response):
logger_debug(
f"[server] RPC Server sent response for request {req}"
)
# Decrement pending count
if req.method_name not in ["_rpc_shutdown", "shutdown"]:
self._num_pending_requests -= 1
def _calculate_adjusted_timeout(self,
req: RPCRequest,
is_streaming: bool = False) -> float:
"""Calculate adjusted timeout based on pending overhead.
Args:
req: The RPC request
is_streaming: Whether this is for a streaming request
Returns:
The adjusted timeout value
"""
adjusted_timeout = req.timeout
if req.creation_timestamp is not None and req.timeout is not None and req.timeout > 0:
pending_time = time.time() - req.creation_timestamp
adjusted_timeout = max(0.1, req.timeout -
pending_time) # Keep at least 0.1s timeout
if pending_time > 0.1: # Only log if significant pending time
method_type = "streaming " if is_streaming else ""
logger_debug(
f"[server] RPC Server adjusted timeout for {method_type}{req.method_name}: "
f"original={req.timeout}s, pending={pending_time:.3f}s, adjusted={adjusted_timeout:.3f}s"
)
return adjusted_timeout
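# Worked example (numbers are illustrative): a request created 3.2 s ago with
# req.timeout = 10.0 gets adjusted_timeout = max(0.1, 10.0 - 3.2) = 6.8 s,
# while a request that has already waited 12 s is clamped to the floor:
# max(0.1, 10.0 - 12.0) = 0.1 s.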
async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]:
"""Process a request. Returns None for streaming requests (handled separately)."""
func = self._functions[req.method_name]
# Calculate adjusted timeout based on pending overhead
adjusted_timeout = self._calculate_adjusted_timeout(req)
try:
if inspect.iscoroutinefunction(func):
# Execute async function directly in event loop, no need to run in executor due to the GIL
logger_debug(
f"[server] RPC Server running async task {req.method_name} in dispatcher"
)
result = await asyncio.wait_for(func(*req.args, **req.kwargs),
timeout=adjusted_timeout)
else:
# Execute sync function in thread executor
loop = asyncio.get_running_loop()
def call_with_kwargs():
return func(*req.args, **req.kwargs)
logger_debug(
f"[server] RPC Server running async task {req.method_name} in worker"
)
# TODO: let num worker control the pool size
result = await asyncio.wait_for(loop.run_in_executor(
self._executor, call_with_kwargs),
timeout=adjusted_timeout)
response = RPCResponse(req.request_id, result=result)
except asyncio.TimeoutError:
response = RPCResponse(
req.request_id,
result=None,
error=RPCTimeout(
f"Method '{req.method_name}' timed out after {req.timeout} seconds",
traceback=traceback.format_exc()))
except Exception as e:
response = RPCResponse(req.request_id,
result=None,
error=RPCError(
str(e),
cause=e,
traceback=traceback.format_exc()))
return response
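# Dispatch sketch (the instance below is illustrative): coroutine methods are
# awaited directly on the server's event loop, while plain methods run through
# loop.run_in_executor().
#
#     class Backend:
#         async def generate(self, prompt):   # awaited in the event loop
#             ...
#         def get_stats(self):                # offloaded to the executor
#             ...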
async def _process_streaming_request(self, req: RPCRequest) -> None:
"""Process a streaming request by sending multiple responses."""
func = self._functions[req.method_name]
if not inspect.isasyncgenfunction(func):
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=RPCStreamingError(
f"Method '{req.method_name}' is not an async generator.",
traceback=traceback.format_exc()),
is_streaming=True,
stream_status='error'),
routing_id=req.routing_id)
return
chunk_index = 0
adjusted_timeout: float = self._calculate_adjusted_timeout(
req, is_streaming=True)
try:
logger_debug(
f"[server] RPC Server running streaming task {req.method_name}")
# Send start signal
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='start'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent start signal for request {req.request_id}",
color="green")
chunk_index += 1
# Apply timeout to the entire streaming operation if specified
if adjusted_timeout is not None and adjusted_timeout > 0:
# Create a task for the async generator with timeout
async def stream_with_timeout():
nonlocal chunk_index
async for result in func(*req.args, **req.kwargs):
if result is None or result == []:
# Skip None values or empty list to save bandwidth
# TODO[Superjomn]: add a flag to control this behavior
continue
# Check if shutdown was triggered
if self._shutdown_event.is_set():
raise RPCCancelled(
"Server is shutting down, streaming cancelled")
logger_debug(
f"[server] RPC Server got data and ready to send result {result}"
)
response = RPCResponse(req.request_id,
result=result,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='data')
if not await self._send_response(req, response):
# Stop streaming after a pickle error
return
logger_debug(
f"[server] Sent response for request {req.request_id}",
color="green")
chunk_index += 1
# Use wait_for for timeout handling
await asyncio.wait_for(stream_with_timeout(),
timeout=adjusted_timeout)
else:
# No timeout specified, stream normally
async for result in func(*req.args, **req.kwargs):
if result is None or result == []:
continue # Skip None values or empty list
# Check if shutdown was triggered
if self._shutdown_event.is_set():
raise RPCCancelled(
"Server is shutting down, streaming cancelled")
logger_debug(
f"[server] RPC Server got data and ready to send result {result}"
)
response = RPCResponse(req.request_id,
result=result,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='data')
if not await self._send_response(req, response):
# Stop streaming after a pickle error
return
chunk_index += 1
# Send end signal
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='end'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent end signal for request {req.request_id}",
color="green")
except RPCCancelled as e:
# Server is shutting down, send cancelled error
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=e,
is_streaming=True,
chunk_index=chunk_index,
stream_status='error'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error signal for request {req.request_id}",
color="green")
except asyncio.TimeoutError:
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=RPCTimeout(
f"Streaming method '{req.method_name}' timed out",
traceback=traceback.format_exc()),
is_streaming=True,
chunk_index=chunk_index,
stream_status='error'),
routing_id=req.routing_id)
except Exception as e:
response = RPCResponse(
req.request_id,
result=None,
error=RPCStreamingError(str(e),
traceback=traceback.format_exc()),
is_streaming=True,
chunk_index=chunk_index,
stream_status='error')
await self._send_response(req, response)
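# Streaming sketch (the generator below is illustrative): an async generator
# registered on the instance yields one 'data' response per chunk, framed by a
# 'start' response and an 'end' (or 'error') response, with chunk_index
# increasing monotonically.
#
#     class Backend:
#         async def stream_tokens(self, prompt):
#             for token in ("a", "b", "c"):
#                 yield token
#
#     # resulting wire sequence for one request:
#     #   stream_status='start' (chunk_index=0)
#     #   stream_status='data'  (chunk_index=1..3, result='a'/'b'/'c')
#     #   stream_status='end'   (chunk_index=4)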
async def _send_response(self, req: RPCRequest,
response: RPCResponse) -> bool:
"""Safely sends a response, handling pickle errors."""
try:
await self._client_socket.put_async(response,
routing_id=req.routing_id)
logger_debug(f"[server] Sent response for request {req.request_id}",
color="green")
return True
except Exception as e:
logger.error(
f"Failed to pickle response for request {req.request_id}: {e}")
error_msg = f"Failed to pickle response: {e}"
if req.is_streaming:
error_cls = RPCStreamingError
chunk_index = response.chunk_index if response else None
error_response = RPCResponse(
req.request_id,
result=None,
error=error_cls(error_msg,
traceback=traceback.format_exc()),
is_streaming=True,
chunk_index=chunk_index,
stream_status='error')
else:
error_cls = RPCError
error_response = RPCResponse(
req.request_id,
result=None,
error=error_cls(error_msg,
traceback=traceback.format_exc()))
try:
await self._client_socket.put_async(error_response,
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error response for request {req.request_id}",
color="green")
except Exception as e_inner:
logger.error(
f"Failed to send error response for request {req.request_id}: {e_inner}"
)
return False
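# Failure sketch: if serializing the result inside put_async() fails (for
# example, a method returns an unpicklable value such as a threading.Lock), the
# exception is caught here and an RPCError / RPCStreamingError response is sent
# in its place; if even that fails, the method returns False.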
def start(self) -> None:
"""Binds sockets, starts workers, and begins proxying messages."""
if self._client_socket is None:
raise RuntimeError(
"Server must be bound to an address before starting. Call bind() first."
)
self._client_socket.setup_lazily()
logger.info(f"RPC Server started and listening on {self._address}")
# Create and configure the event loop
self._loop = asyncio.new_event_loop()
self._shutdown_event = asyncio.Event()
async def run_server():
"""Run the server until shutdown."""
try:
await self._run_server()
except asyncio.CancelledError:
logger_debug("[server] Server task cancelled")
except Exception as e:
logger.error(f"Server error: {e}")
logger.error(traceback.format_exc())
finally:
# Cancel all worker tasks
for task in self._worker_tasks:
if not task.done():
task.cancel()
# Wait for all tasks to complete
if self._worker_tasks:
await asyncio.gather(*self._worker_tasks,
return_exceptions=True)
# Drain any remaining requests and send cancellation responses
await self._drain_pending_requests()
logger_debug("[server] All server tasks completed")
self._main_task = self._loop.create_task(run_server())
def run_loop():
asyncio.set_event_loop(self._loop)
try:
self._loop.run_until_complete(self._main_task)
except RuntimeError as e:
# This can happen if the event loop is stopped while futures are pending
error_str = str(e)
if "Event loop stopped before Future completed" in error_str:
# This is expected during shutdown - ignore it
logger.debug(
f"[server] Expected shutdown error: {error_str}")
else:
# This is an unexpected RuntimeError - log full details
import traceback
logger.error(f"Event loop error: {error_str}")
logger.error(f"Traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"Event loop error: {e}")
finally:
# Clean up any remaining tasks
pending = asyncio.all_tasks(self._loop)
for task in pending:
task.cancel()
if pending:
try:
self._loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True))
except RuntimeError:
# Event loop might already be closed
pass
self._loop.close()
self._server_thread = threading.Thread(target=run_loop,
name="rpc_server_thread",
daemon=True)
self._server_thread.start()
logger.info("RPC Server has started.")
Server = RPCServer
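# Minimal end-to-end sketch (not part of the public API; the address, worker
# count, and sleep duration are arbitrary choices). It exercises only bind(),
# start(), and shutdown() against a tiny object whose public methods become
# RPC functions; a real deployment would drive the server from an RPC client
# instead of sleeping.
if __name__ == "__main__":

    class _Echo:

        def echo(self, msg: str) -> str:
            return msg

    demo_server = RPCServer(_Echo(), num_workers=2)
    demo_server.bind("tcp://127.0.0.1:5555")
    demo_server.start()
    try:
        time.sleep(1.0)  # a client would connect and issue calls here
    finally:
        demo_server.shutdown()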