import asyncio
import concurrent.futures
import threading
import uuid
from typing import Any, AsyncIterator, Dict, Optional

from ...llmapi.utils import AsyncQueue, _SyncQueue, logger_debug
from ...logger import logger
from ..ipc import ZeroMqQueue
from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse,
                         RPCStreamingError, RPCTimeout)


class RemoteCall:
    """Helper class to enable chained remote call syntax like client.method().remote()"""

    def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs):
        self.client = client
        self.method_name = method_name
        self.args = args
        self.kwargs = kwargs

    def _prepare_and_call(self, timeout: Optional[float], need_response: bool,
                          mode: str, call_method: str) -> Any:
        """Common method to prepare RPC params and make the call.

        Args:
            timeout: Timeout for the RPC call
            need_response: Whether a response is expected
            mode: The RPC mode ("sync", "async", "future")
            call_method: The method name to call on the client

        Returns:
            The result of the client method call
        """
        rpc_params = RPCParams(timeout=timeout,
                               need_response=need_response,
                               mode=mode)
        self.kwargs["__rpc_params"] = rpc_params
        client_method = getattr(self.client, call_method)
        return client_method(self.method_name, *self.args, **self.kwargs)

    def remote(self,
               timeout: Optional[float] = None,
               need_response: bool = True) -> Any:
        """Synchronous remote call with optional RPC parameters."""
        return self._prepare_and_call(timeout, need_response, "sync",
                                      "_call_sync")

    def remote_async(self,
                     timeout: Optional[float] = None,
                     need_response: bool = True):
        """Asynchronous remote call that returns a coroutine."""
        return self._prepare_and_call(timeout, need_response, "async",
                                      "_call_async")

    def remote_future(self,
                      timeout: Optional[float] = None,
                      need_response: bool = True) -> concurrent.futures.Future:
        """Remote call that returns a Future object."""
        return self._prepare_and_call(timeout, need_response, "future",
                                      "_call_future")

    def remote_streaming(self,
                         timeout: Optional[float] = None) -> AsyncIterator[Any]:
        """Remote call for streaming results."""
        # Streaming always needs a response
        return self._prepare_and_call(timeout, True, "async", "_call_streaming")

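
# Illustrative sketch of the chained-call styles RemoteCall enables. The server
# setup and the `fetch_stats` / `stream_stats` method names and the address
# below are hypothetical, not part of this module; they only show how the four
# `remote*` variants are invoked.
#
#   client = RPCClient("ipc:///tmp/rpc_demo")            # hypothetical address
#   stats = client.fetch_stats(detail=True).remote()     # blocking call
#   result = await client.fetch_stats().remote_async()   # awaitable coroutine
#   fut = client.fetch_stats().remote_future()           # concurrent.futures.Future
#   async for item in client.stream_stats().remote_streaming():
#       ...                                               # streaming generator
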
class RPCClient:
    """
    An RPC Client that connects to the RPCServer.
    """

    def __init__(self,
                 address: str,
                 hmac_key=None,
                 timeout: Optional[float] = None,
                 num_workers: int = 4):
        '''
        Args:
            address: The ZMQ address to connect to.
            hmac_key: The HMAC key for encryption.
            timeout: The timeout (seconds) for RPC calls.
            num_workers: The number of workers for the RPC client.
        '''
        self._address = address
        self._timeout = timeout
        self._client_socket = ZeroMqQueue(address=(address, hmac_key),
                                          is_server=False,
                                          is_async=True,
                                          use_hmac_encryption=False)
        self._pending_futures = {}
        # map request_id to the queue for streaming responses
        self._streaming_queues: Dict[str, AsyncQueue] = {}
        self._reader_task = None
        self._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=num_workers, thread_name_prefix="rpc_client_worker")

        self._server_stopped = False
        self._closed = False
        self._stop_event = None
        self._loop = None
        self._loop_thread = None

        logger_debug(f"RPC Client initialized. Connected to {self._address}")
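
    # Illustrative sketch (not part of the API surface): a client is typically
    # created with the same ZMQ address the RPCServer is bound to and used as a
    # context manager so close() runs on exit. The address below is a
    # hypothetical example.
    #
    #   with RPCClient("ipc:///tmp/rpc_demo", timeout=10.0) as client:
    #       client.some_method().remote()
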
    def shutdown_server(self):
        """Shut down the remote server."""
        if self._server_stopped:
            return

        self._rpc_shutdown().remote()

        self._server_stopped = True

    def close(self):
        """Gracefully close the client, cleaning up background tasks."""

        if self._closed:
            return
        # Mark the client closed so no new work is accepted
        self._closed = True

        logger_debug("RPC Client closing")

        if self._stop_event and self._loop:
            # Use call_soon_threadsafe since set() is not a coroutine
            self._loop.call_soon_threadsafe(self._stop_event.set)

        if self._reader_task:
            try:
                self._reader_task.result(timeout=1.0)
            except concurrent.futures.TimeoutError:
                logger.warning(
                    "Reader task did not exit gracefully, cancelling")
                self._reader_task.cancel()
            except Exception as e:
                # Task might have already finished or been cancelled
                logger_debug(f"Reader task cleanup: {e}")
            self._reader_task = None

        if self._loop and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._loop_thread:
            self._loop_thread.join()
            self._loop_thread = None
        if self._executor:
            self._executor.shutdown(wait=True)

        if self._client_socket:
            self._client_socket.close()
            self._client_socket = None

        logger_debug("RPC Client closed")
    async def _response_reader(self):
        """Task to read responses from the socket and set results on futures."""
        logger_debug("Response reader started")

        while not self._stop_event.is_set():
            try:
                # Use wait_for with a short timeout to periodically check stop event
                try:
                    response: RPCResponse = await asyncio.wait_for(
                        self._client_socket.get_async(),
                        timeout=0.1  # Check stop event every 100ms
                    )
                except asyncio.TimeoutError:
                    # Timeout is expected - just check stop event and continue
                    continue

                logger_debug(f"RPC Client received response: {response}")
                logger_debug(
                    f"Response request_id: {response.request_id}, is_streaming: {response.is_streaming}"
                )
                logger_debug(
                    f"Pending futures: {list(self._pending_futures.keys())}")

                # Handle streaming responses
                if response.is_streaming:
                    assert response.stream_status in [
                        'start', 'data', 'end', 'error'
                    ], f"Invalid stream status: {response.stream_status}"
                    queue = self._streaming_queues.get(response.request_id)
                    if queue:
                        # put to the sync queue, as the current event loop is
                        # different from the one in call_async or call_streaming
                        assert isinstance(queue, AsyncQueue)
                        logger_debug(
                            f"RPC Client putting response to AsyncQueue: {response}"
                        )
                        queue.sync_q.put(response)
                        # Clean up if stream ended
                        if response.stream_status in ['end', 'error']:
                            self._streaming_queues.pop(response.request_id,
                                                       None)
                else:
                    # Handle regular responses
                    logger_debug(
                        f"Handling regular response for request_id: {response.request_id}"
                    )
                    if future_info := self._pending_futures.get(
                            response.request_id):
                        future, target_loop = future_info
                        logger_debug(
                            f"Found future for request_id: {response.request_id}, future done: {future.done()}"
                        )

                        if not future.done():
                            if response.error is None:
                                logger_debug(
                                    f"Setting result for request_id: {response.request_id}, result: {response.result}"
                                )
                                target_loop.call_soon_threadsafe(
                                    future.set_result, response.result)
                            else:
                                # Use the original RPCError from the response
                                logger_debug(
                                    f"Setting exception for request_id: {response.request_id}, error: {response.error}"
                                )
                                target_loop.call_soon_threadsafe(
                                    future.set_exception, response.error)
                    else:
                        logger_debug(
                            f"No future found for request_id: {response.request_id}"
                        )
                    self._pending_futures.pop(response.request_id, None)

            except asyncio.CancelledError:
                # Still handle cancellation for backward compatibility
                logger_debug("Response reader cancelled")
                break
            except Exception as e:
                logger.error(f"Exception in RPC response reader: {e}")
                # Propagate exception to all pending futures
                for (future, target_loop) in self._pending_futures.values():
                    if not future.done():
                        target_loop.call_soon_threadsafe(
                            future.set_exception, e)
                # Also signal error to streaming queues
                for queue in self._streaming_queues.values():
                    await queue.put(RPCResponse("", None, e, False, 0, 'error'))
                break

        logger_debug("Response reader exiting gracefully")
        self._reader_task = None

    def _start_response_reader_lazily(self):
        if self._reader_task is None or self._reader_task.done():
            # Ensure we have a persistent background loop
            self._ensure_event_loop()
            # Always create the reader task on the persistent loop
            future = asyncio.run_coroutine_threadsafe(self._response_reader(),
                                                      self._loop)
            # Store the concurrent.futures.Future
            self._reader_task = future
    async def _call_async(self, method_name, *args, **kwargs):
        """Async version of RPC call.

        Args:
            method_name: Method name to call
            *args: Positional arguments
            **kwargs: Keyword arguments
                __rpc_params: RPCParams object containing RPC parameters.

        Returns:
            The result of the remote method call
        """
        logger_debug(
            f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
        )
        if self._server_stopped:
            raise RPCCancelled("Server is shutting down, request cancelled")

        self._start_response_reader_lazily()
        rpc_params = kwargs.pop("__rpc_params", RPCParams())
        need_response = rpc_params.need_response
        timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout

        request_id = uuid.uuid4().hex
        request = RPCRequest(request_id,
                             method_name,
                             args,
                             kwargs,
                             need_response,
                             timeout=timeout)
        logger_debug(f"RPC client sending request: {request}")
        await self._client_socket.put_async(request)

        if not need_response:
            return None

        loop = asyncio.get_running_loop()
        future = loop.create_future()
        logger_debug(
            f"RPC Client _call_async: Created future for request_id: {request_id} in loop: {id(loop)}"
        )
        self._pending_futures[request_id] = (future, loop)
        logger_debug("RPC Client _call_async: Stored future in pending_futures")

        try:
            # On timeout, the remote call returns a timeout error promptly, so
            # give the client-side wait one extra second to ensure that result
            # can still be received.
            logger_debug(
                f"RPC Client _call_async: Awaiting future for request_id: {request_id}"
            )
            if timeout is None:
                res = await future
            else:
                # Add 1 second so the server's own timeout error can arrive first
                res = await asyncio.wait_for(future, timeout + 1)
            logger_debug(
                f"RPC Client _call_async: Got result for request_id: {request_id}: {res}"
            )
            return res
        except RPCCancelled:
            self._server_stopped = True
            raise
        except asyncio.TimeoutError:
            raise RPCTimeout(
                f"Request '{method_name}' timed out after {timeout}s")
        except Exception:
            raise
        finally:
            self._pending_futures.pop(request_id, None)
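
    # Illustrative sketch of the async path (the `compute` method name and the
    # address are hypothetical): inside a running event loop, the coroutine
    # returned by remote_async() drives _call_async above, e.g.
    #
    #   async def main():
    #       client = RPCClient("ipc:///tmp/rpc_demo")
    #       value = await client.compute(1, 2).remote_async()
    #
    # A per-call timeout may be passed: .remote_async(timeout=5.0) raises
    # RPCTimeout if no response arrives in time.
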
    def _ensure_event_loop(self):
        """Ensure we have a running event loop in a background thread."""
        if self._loop is None or not self._loop.is_running():
            self._loop = asyncio.new_event_loop()

            def run_loop():
                asyncio.set_event_loop(self._loop)
                self._stop_event = asyncio.Event()
                self._loop.run_forever()

            self._loop_thread = threading.Thread(target=run_loop,
                                                 daemon=True,
                                                 name="rpc_client_loop")
            self._loop_thread.start()

            # Give the loop a moment to start
            import time
            time.sleep(0.1)
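
    # Design note: _call_sync and _call_future run _call_async on this single
    # background loop via run_coroutine_threadsafe, while remote_async awaits it
    # directly on the caller's own loop; in both cases _response_reader delivers
    # results with call_soon_threadsafe onto whichever loop created the future.
    # The short sleep above is a simple way to let run_forever() start before
    # the first coroutine is scheduled.
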
    def _call_sync(self, method_name, *args, **kwargs):
        """Synchronous version of RPC call."""
        logger_debug(
            f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
        )
        self._ensure_event_loop()
        logger_debug(
            f"RPC Client _call_sync: Creating future for {method_name}")
        future = asyncio.run_coroutine_threadsafe(
            self._call_async(method_name, *args, **kwargs), self._loop)
        logger_debug(
            f"RPC Client _call_sync: Waiting for result of {method_name}")
        result = future.result()
        logger_debug(
            f"RPC Client _call_sync: Got result for {method_name}: {result}")
        return result

    def _call_future(self, name: str, *args,
                     **kwargs) -> concurrent.futures.Future:
        """
        Call a remote method and return a Future.

        Args:
            name: Method name to call
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            A Future object that can be used to retrieve the result
        """

        def _async_to_sync():
            self._ensure_event_loop()
            future = asyncio.run_coroutine_threadsafe(
                self._call_async(name, *args, **kwargs), self._loop)
            return future.result()

        return self._executor.submit(_async_to_sync)
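
    # Illustrative sketch (hypothetical `compute` method): remote_future() suits
    # fire-and-collect patterns on the calling thread, e.g.
    #
    #   futures = [client.compute(i).remote_future() for i in range(4)]
    #   results = [f.result(timeout=30) for f in futures]
    #
    # Each call occupies one of the client's num_workers worker threads until it
    # resolves; further submissions queue behind them.
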
    async def _call_streaming(self, name: str, *args,
                              **kwargs) -> AsyncIterator[Any]:
        """
        Call a remote async generator method and get streaming results.

        Args:
            name: Method name to call
            *args: Positional arguments
            **kwargs: Keyword arguments

        Yields:
            Results from the remote async generator
        """
        if self._server_stopped:
            raise RPCCancelled("Server is shutting down, request cancelled")

        self._start_response_reader_lazily()
        rpc_params = kwargs.pop("__rpc_params", RPCParams())
        timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout

        request_id = uuid.uuid4().hex
        # Use AsyncQueue to ensure proper cross-thread communication
        queue = AsyncQueue()
        # Recreate sync_q with the current running loop for proper cross-thread communication
        # This ensures the background _response_reader thread can properly notify this event loop
        queue._sync_q = _SyncQueue(queue, asyncio.get_running_loop())
        self._streaming_queues[request_id] = queue

        try:
            # Send streaming request
            request = RPCRequest(request_id,
                                 name,
                                 args,
                                 kwargs,
                                 need_response=True,
                                 timeout=timeout,
                                 is_streaming=True)
            await self._client_socket.put_async(request)

            # Read streaming responses
            while True:
                logger_debug("RPC Client _call_streaming waiting for response",
                             color="green")
                if timeout is None:
                    response = await queue.get()
                else:
                    response = await asyncio.wait_for(queue.get(),
                                                      timeout=timeout)

                logger_debug(
                    f"RPC Client _call_streaming received [{response.stream_status}] response: {response}",
                    color="green")
                if response.stream_status == 'start':
                    # Start of stream
                    continue
                elif response.stream_status == 'data':
                    logger_debug(
                        f"RPC Client _call_streaming received data: {response.result}",
                        color="green")
                    yield response.result
                elif response.stream_status == 'end':
                    # End of stream
                    break
                elif response.stream_status == 'error':
                    # Error in stream
                    if response.error:
                        raise response.error
                    else:
                        raise RPCStreamingError("Unknown streaming error")

        except asyncio.TimeoutError:
            raise RPCTimeout(
                f"Streaming request '{name}' timed out after {timeout}s")
        finally:
            # Clean up
            self._streaming_queues.pop(request_id, None)
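
    # Illustrative sketch (hypothetical `generate_tokens` method): streaming
    # mirrors a remote async generator registered on the server, e.g.
    #
    #   async def consume():
    #       async for chunk in client.generate_tokens(prompt).remote_streaming():
    #           print(chunk)
    #
    # RPCTimeout is raised if the gap between two stream items exceeds the
    # timeout; if the server reports an error without an exception payload,
    # RPCStreamingError is raised instead.
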
    def get_server_attr(self, name: str):
        """Get the attribute of the RPC server.

        This is mainly used for testing.
        """
        return self._rpc_get_attr(name).remote()

    def __getattr__(self, name):
        """
        Magically handles calls to non-existent methods.
        Returns a callable that, when invoked, returns a RemoteCall instance.

        This enables the new syntax:
            client.method(args).remote()
            await client.method(args).remote_async()
            client.method(args).remote_future()
            async for x in client.method(args).remote_streaming()
        """
        logger_debug(f"RPC Client getting attribute: {name}")

        def method_caller(*args, **kwargs):
            return RemoteCall(self, name, *args, **kwargs)

        return method_caller

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()
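

# Minimal end-to-end sketch, assuming an RPCServer from this package is already
# serving a `ping` method at the (hypothetical) address below. It is not run on
# import and only shows the lifecycle: connect, call, shut down.
if __name__ == "__main__":
    demo_address = "ipc:///tmp/rpc_client_demo"  # hypothetical address
    with RPCClient(demo_address, timeout=10.0) as client:
        print(client.ping().remote())  # blocking round trip
        client.shutdown_server()  # ask the server to stop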