import asyncio
import inspect
import queue
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from ...llmapi.utils import ManagedThread, logger_debug
from ...logger import logger
from ..ipc import ZeroMqQueue
from .rpc_common import (RPCError, RPCRequest, RPCResponse, RPCStreamingError,
                         RPCTimeout)


class RPCServer:
    """
    An RPC Server that listens for requests and executes them concurrently.
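
    Example (an illustrative sketch; ``Calculator`` and the address below are
    placeholders, not part of this module):

        class Calculator:
            def add(self, a: int, b: int) -> int:
                return a + b

        with RPCServer(Calculator()) as server:
            server.bind("tcp://127.0.0.1:5555")
            server.start()
            ...  # remote clients can now call "add" by name
        # Leaving the ``with`` block calls shutdown().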
    """

    def __init__(self,
                 instance,
                 hmac_key=None,
                 num_workers: int = 4,
                 timeout: float = 0.5,
                 async_run_task: bool = False):
        """
        Initializes the server with an instance whose methods it exposes.

        Args:
            instance: The instance whose public methods will be exposed via RPC.
            hmac_key (bytes, optional): HMAC key for encryption.
            num_workers (int): Number of worker threads or worker tasks used to
                parallelize task execution.
            timeout (float): Timeout in seconds for RPC calls.
            async_run_task (bool): Whether to run tasks asynchronously in a
                dedicated thread pool.

        NOTE: Increase num_workers if some streaming tasks run indefinitely,
            since each such task occupies a worker.
        """
        self._instance = instance
        self._hmac_key = hmac_key
        self._num_workers = num_workers
        self._address = None
        self._timeout = timeout
        self._client_socket = None

        # Once the stop event is set, the dispatcher and all workers will exit.
        self._stop_event = threading.Event()

        self._num_pending_requests = 0

        self._functions = {
            "_rpc_shutdown": lambda: self.shutdown(is_remote_call=True),
            "_rpc_get_attr": lambda name: self.get_attr(name),
        }
        self._dispatcher_thread: Optional[ManagedThread] = None
        if async_run_task:
            self._executor = ThreadPoolExecutor(
                max_workers=num_workers, thread_name_prefix="rpc_server_worker")
        else:
            self._executor = None

        self._queue = None

        # Automatically register the instance's public methods.
        self.register_instance(instance)

        logger_debug(f"RPC Server initialized with {num_workers} workers.",
                     color="green")

    @property
    def address(self) -> str:
        assert self._client_socket is not None, "Client socket is not bound"
        return self._client_socket.address[0]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()

    def bind(self, address="tcp://*:5555"):
        """
        Bind the server to the specified address.

        Args:
            address (str): The ZMQ address to bind the client-facing socket to.
        """
        self._address = address
        self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
                                          is_server=True,
                                          is_async=True,
                                          use_hmac_encryption=False)
        logger.info(f"RPC Server bound to {self._address}")

    def shutdown(self, is_remote_call: bool = False):
        """Internal method to trigger server shutdown.

        Args:
            is_remote_call: Whether the shutdown is called by a remote call.
                This should be True when client.server_shutdown() is called.
        """
        # NOTE: shutdown is also a remote method, so it may itself be executed
        # by a thread in the worker executor.

        if self._stop_event.is_set():
            return

        logger_debug(
            "RPC Server shutdown signal received. Terminating server...")

        # Setting the stop event tells the dispatcher routine and the worker
        # routines to prepare for exit: stop accepting new requests, but keep
        # processing the pending ones.
        self._stop_event.set()

        # The worker routines should drain the pending requests.
        logger_debug(
            f"RPC Server shutdown: {self._num_pending_requests} pending requests"
        )

        while self._num_pending_requests > 0:
            time.sleep(0.01)
        logger_debug("RPC Server shutdown finished pending requests")

        if not is_remote_call:
            # Block the calling thread until shutdown is finished.

            # 1. Wait for the dispatcher thread to exit, so that no new requests
            #    are accepted.
            logger_debug("RPC Server dispatcher thread joining")
            if self._dispatcher_thread:
                self._dispatcher_thread.join()
                self._dispatcher_thread = None
            logger_debug("RPC Server dispatcher thread joined")

            # 2. Shut down the executor; it waits for the pending requests to be
            #    processed.
            if self._executor:
                self._executor.shutdown(wait=True)
                self._executor = None

            # 3. (Optionally) Close the client socket. This doesn't affect
            #    anything, since the zmq client will not time out even if the
            #    target is not available.
            if self._client_socket:
                self._client_socket.close()
        else:
            # If the shutdown is called remotely, this method itself is executed
            # in an executor thread, so we cannot join the dispatcher thread:
            # the dispatcher thread is awaiting the shutdown result.
            logger_debug(
                f"RPC Server to shutdown: {self._num_pending_requests} pending requests"
            )

            while self._num_pending_requests > 0:
                time.sleep(0.01)
            logger_debug("RPC Server shutdown finished pending requests")

    def register_function(self, func, name=None):
        """Exposes a single function to clients."""
        fname = name or func.__name__
        if fname in self._functions:
            logger.warning(
                f"Function '{fname}' is already registered. Overwriting.")
        self._functions[fname] = func
        logger_debug(f"Registered function: {fname}")

    def register_instance(self, instance):
        """Exposes all public methods of a class instance."""
        logger_debug(
            f"Registering instance of class: {instance.__class__.__name__}")
        for name in dir(instance):
            if not name.startswith('_'):
                attr = getattr(instance, name)
                if callable(attr):
                    self.register_function(attr, name)

    def get_attr(self, name: str):
        """Get an attribute of the RPC server.

        This is mainly used for testing.
        """
        return getattr(self, name)
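
    # Internal flow: the dispatcher routine pulls RPCRequests from the ZMQ
    # socket and puts them onto an internal asyncio.Queue; ``num_workers``
    # worker routines drain that queue, execute the registered functions, and
    # send RPCResponses back over the same socket.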

    async def _dispatcher_routine(self, stop_event: threading.Event):
        assert self._client_socket is not None, "Client socket is not bound"
        assert self._queue is not None, "RPC queue is not initialized"

        # On shutdown, the dispatcher exits first, and the workers continue to
        # process the pending requests.
        while not stop_event.is_set():
            try:
                req: RPCRequest = await self._client_socket.get_async_noblock(
                    timeout=0.5)
                logger_debug(f"RPC dispatcher got request: {req}")
            except asyncio.TimeoutError:
                await asyncio.sleep(0)
                continue
            except Exception as e:
                logger.error(f"RPC dispatcher caught an exception: {e}")
                logger.error(traceback.format_exc())
                continue

            await self._queue.put(req)  # type: ignore

            # The shutdown methods depend on _num_pending_requests, so they
            # should not be counted.
            if req.method_name not in ["_rpc_shutdown", "shutdown"]:
                self._num_pending_requests += 1
                logger_debug(
                    f"Dispatcher received request {req}, pending: {self._num_pending_requests}"
                )

    # TODO optimization: resolve the sequential scheduling of remote calls.
    # If tons of submit() remote calls block the FIFO queue, later get_stats()
    # remote calls may be blocked behind them.
    # There could be two dispatch modes:
    #   1. (current) mixed mode: all calls share the same routine/pool
    #   2. (promising) stream mode: specific remote_call -> stream -> specific routine/pool
    #      - get_stats() - 1 -> dedicated queue -> dedicated routine/pool
    #      - submit() - 3 -> dedicated queue -> dedicated routine/pool
    # TODO potential optimization: for submit(), batch the ad-hoc requests over a
    # short interval (e.g. 5 ms) to reduce the IPC count.
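    #
    # An illustrative sketch of the "stream mode" above (an assumption, not
    # implemented here): keep a per-method queue map such as
    #     {"submit": asyncio.Queue(), "get_stats": asyncio.Queue()}
    # have the dispatcher route each request to its method's queue, and give
    # each queue a small dedicated worker pool, so bursts of submit() calls
    # cannot delay lightweight calls such as get_stats().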
    async def _worker_routine(self, stop_event: threading.Event):
        """The routine executed by each worker task."""
        assert self._client_socket is not None, "Client socket is not bound"
        assert self._queue is not None, "RPC queue is not initialized"

        while (not stop_event.is_set()) or self._num_pending_requests > 0:
            try:
                req: RPCRequest = await asyncio.wait_for(
                    self._queue.get(),  # type: ignore
                    timeout=self._timeout)
            except asyncio.TimeoutError:
                await asyncio.sleep(0)
                continue

            # Check whether the requested method is registered.
            if req.method_name not in self._functions:
                logger.error(
                    f"Method '{req.method_name}' not found in RPC server.")
                self._num_pending_requests -= 1

                if not req.need_response:
                    continue
                if req.is_streaming:
                    await self._client_socket.put_async(
                        RPCResponse(
                            req.request_id,
                            None,
                            RPCStreamingError(
                                f"Method '{req.method_name}' not found in RPC server.",
                                traceback=traceback.format_exc()),
                            stream_status='error'))
                else:
                    response = RPCResponse(
                        req.request_id,
                        None,
                        RPCError(
                            f"Method '{req.method_name}' not found in RPC server.",
                            traceback=traceback.format_exc()),
                    )
                    await self._client_socket.put_async(response)

                continue

            func = self._functions[req.method_name]
            if req.is_streaming:
                if inspect.isasyncgenfunction(func):
                    await self._process_streaming_request(req)
                else:
                    # Non-streaming function called with the streaming flag.
                    response = RPCResponse(
                        req.request_id,
                        None,
                        RPCStreamingError(
                            f"Method '{req.method_name}' is not a streaming function."
                        ),
                        # Redirect the error to the client's streaming queue.
                        is_streaming=True,
                        stream_status='error',
                    )
                    await self._client_socket.put_async(response)
            else:
                # Process a regular request.
                response = await self._process_request(req)

                # Some tasks don't need a response, e.g. submit_request or shutdown.
                if req.need_response and response is not None:
                    logger_debug(
                        f"RPC Server sending response for request {req}, pending: {self._num_pending_requests}"
                    )
                    if await self._send_response(req, response):
                        logger_debug(
                            f"RPC Server sent response for request {req}")

            # Only decrement if this request was counted in the first place.
            if req.method_name not in ["_rpc_shutdown", "shutdown"]:
                self._num_pending_requests -= 1

    def _calculate_adjusted_timeout(self,
                                    req: RPCRequest,
                                    is_streaming: bool = False) -> float:
        """Calculate adjusted timeout based on pending overhead.

        Args:
            req: The RPC request.
            is_streaming: Whether this is for a streaming request.

        Returns:
            The adjusted timeout value.
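
        For example (illustrative numbers): a request created 0.3 s ago with a
        2.0 s timeout gets an adjusted timeout of max(0.1, 2.0 - 0.3) = 1.7 s,
        so the time the request already spent queued counts against its
        overall timeout.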
        """
        adjusted_timeout = req.timeout
        if req.creation_timestamp is not None and req.timeout is not None and req.timeout > 0:
            pending_time = time.time() - req.creation_timestamp
            adjusted_timeout = max(0.1, req.timeout -
                                   pending_time)  # Keep at least 0.1s timeout
            if pending_time > 0.1:  # Only log if significant pending time
                method_type = "streaming " if is_streaming else ""
                logger_debug(
                    f"RPC Server adjusted timeout for {method_type}{req.method_name}: "
                    f"original={req.timeout}s, pending={pending_time:.3f}s, adjusted={adjusted_timeout:.3f}s"
                )
        return adjusted_timeout

    async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]:
        """Process a request. Returns None for streaming requests (handled separately)."""
        func = self._functions[req.method_name]

        # Calculate the adjusted timeout based on pending overhead.
        adjusted_timeout = self._calculate_adjusted_timeout(req)

        try:
            if inspect.iscoroutinefunction(func):
                # Execute the async function directly in the event loop; no need
                # to run it in the executor due to the GIL.
                logger_debug(
                    f"RPC Server running async task {req.method_name} in dispatcher"
                )
                result = await asyncio.wait_for(func(*req.args, **req.kwargs),
                                                timeout=adjusted_timeout)
            else:
                # Execute the sync function in the thread executor.
                loop = asyncio.get_running_loop()

                def call_with_kwargs():
                    return func(*req.args, **req.kwargs)

                logger_debug(
                    f"RPC Server running sync task {req.method_name} in worker"
                )
                # TODO: let num_workers control the pool size
                result = await asyncio.wait_for(loop.run_in_executor(
                    self._executor, call_with_kwargs),
                                                timeout=adjusted_timeout)

            logger_debug(f"RPC Server returned result for request {req}")
            response = RPCResponse(req.request_id, result)

        except asyncio.TimeoutError:
            response = RPCResponse(
                req.request_id, None,
                RPCTimeout(
                    f"Method '{req.method_name}' timed out after {req.timeout} seconds",
                    traceback=traceback.format_exc()))

        except Exception as e:
            response = RPCResponse(
                req.request_id, None,
                RPCError(str(e), cause=e, traceback=traceback.format_exc()))

        return response

    async def _process_streaming_request(self, req: RPCRequest):
        """Process a streaming request by sending multiple responses."""
        func = self._functions[req.method_name]

        if not inspect.isasyncgenfunction(func):
            await self._client_socket.put_async(
                RPCResponse(
                    req.request_id,
                    None,
                    RPCStreamingError(
                        f"Method '{req.method_name}' is not an async generator.",
                        traceback=traceback.format_exc()),
                    # Redirect the error to the client's streaming queue.
                    stream_status='error'))
            return

        sequence_number = 0

        # Calculate adjusted timeout based on pending overhead.
        adjusted_timeout = self._calculate_adjusted_timeout(req,
                                                            is_streaming=True)

        try:
            logger_debug(f"RPC Server running streaming task {req.method_name}")
            # Send start signal.
            await self._client_socket.put_async(
                RPCResponse(req.request_id, None, None, True, sequence_number,
                            'start'))
            sequence_number += 1

            # Apply the timeout to the entire streaming operation if specified.
            if adjusted_timeout is not None and adjusted_timeout > 0:
                # Create a task for the async generator with timeout.
                async def stream_with_timeout():
                    nonlocal sequence_number
                    async for result in func(*req.args, **req.kwargs):
                        logger_debug(
                            f"RPC Server got data and ready to send result {result}"
                        )
                        response = RPCResponse(req.request_id, result, None,
                                               True, sequence_number, 'data')
                        if not await self._send_response(req, response):
                            # Stop streaming after a pickle error.
                            return
                        sequence_number += 1

                # Use wait_for for timeout handling.
                await asyncio.wait_for(stream_with_timeout(),
                                       timeout=adjusted_timeout)
            else:
                # No timeout specified, stream normally.
                async for result in func(*req.args, **req.kwargs):
                    logger_debug(
                        f"RPC Server got data and ready to send result {result}"
                    )
                    response = RPCResponse(req.request_id, result, None, True,
                                           sequence_number, 'data')
                    if not await self._send_response(req, response):
                        # Stop streaming after a pickle error.
                        return
                    sequence_number += 1

            # Send end signal.
            await self._client_socket.put_async(
                RPCResponse(req.request_id, None, None, True, sequence_number,
                            'end'))

        except asyncio.TimeoutError:
            await self._client_socket.put_async(
                RPCResponse(
                    req.request_id, None,
                    RPCTimeout(
                        f"Streaming method '{req.method_name}' timed out",
                        traceback=traceback.format_exc()), True,
                    sequence_number, 'error'))

        except Exception as e:
            response = RPCResponse(
                req.request_id, None,
                RPCStreamingError(str(e), traceback=traceback.format_exc()),
                True, sequence_number, 'error')
            await self._send_response(req, response)

    async def _send_response(self, req: RPCRequest,
                             response: RPCResponse) -> bool:
        """Safely sends a response, handling pickle errors."""
        try:
            await self._client_socket.put_async(response)
            return True
        except Exception as e:
            logger.error(
                f"Failed to pickle response for request {req.request_id}: {e}")
            error_msg = f"Failed to pickle response: {e}"
            if req.is_streaming:
                error_cls = RPCStreamingError
                # For streaming, we also need the sequence number; the original
                # response has it.
                sequence_number = response.sequence_number if response else None
                error_response = RPCResponse(
                    req.request_id,
                    None,
                    error_cls(error_msg, traceback=traceback.format_exc()),
                    is_streaming=True,
                    sequence_number=sequence_number,
                    stream_status='error')
            else:
                error_cls = RPCError
                error_response = RPCResponse(
                    req.request_id, None,
                    error_cls(error_msg, traceback=traceback.format_exc()))

            try:
                await self._client_socket.put_async(error_response)
            except Exception as e_inner:
                logger.error(
                    f"Failed to send error response for request {req.request_id}: {e_inner}"
                )
            return False

    def start(self):
        """Sets up the bound socket and starts the dispatcher and worker routines.

        The server must already be bound to an address via bind().
        """
        if self._client_socket is None:
            raise RuntimeError(
                "Server must be bound to an address before starting. Call bind() first."
            )

        self._client_socket.setup_lazily()
        logger.info(f"RPC Server started and listening on {self._address}")

        async def tasks():
            self._queue = asyncio.Queue()
            await asyncio.gather(
                self._dispatcher_routine(self._stop_event), *[
                    self._worker_routine(self._stop_event)
                    for _ in range(self._num_workers)
                ])

        def loop() -> bool:
            asyncio.run(tasks())
            return True  # ManagedThread

        error_queue = queue.Queue()
        self._dispatcher_thread = ManagedThread(task=loop,
                                                stop_event=self._stop_event,
                                                name="rpc_dispatcher_thread",
                                                error_queue=error_queue)
        self._dispatcher_thread.start()

        logger.info("RPC Server has started.")


Server = RPCServer