mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
721 lines
30 KiB
Python
721 lines
30 KiB
Python
import asyncio
|
||
import concurrent.futures
|
||
import os
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Any, AsyncIterator, Dict, Optional
|
||
|
||
import zmq
|
||
|
||
from tensorrt_llm._utils import (customized_gc_thresholds, nvtx_mark_debug,
|
||
nvtx_range_debug)
|
||
|
||
from ...llmapi.utils import (AsyncQueue, _SyncQueue, enable_llmapi_debug,
|
||
logger_debug)
|
||
from ...logger import logger
|
||
from ..ipc import ZeroMqQueue
|
||
from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse,
|
||
RPCStreamingError, RPCTimeout)
|
||
|
||
|
||
class RemoteCall:
    """Deferred RPC invocation enabling ``client.method(...).remote()`` chaining.

    An instance only records the target client, the remote method name and the
    call arguments. Nothing is sent until one of the ``remote*`` methods is
    invoked; that method selects which private ``RPCClient`` entry point
    performs the call.
    """

    def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs):
        # Store the call description verbatim; it is replayed by
        # _prepare_and_call().
        self.client, self.method_name = client, method_name
        self.args, self.kwargs = args, kwargs

    def _prepare_and_call(self, timeout: Optional[float], need_response: bool,
                          mode: str, call_method: str) -> Any:
        """Attach the RPC parameters and dispatch to the chosen client method.

        Args:
            timeout: Timeout for the RPC call
            need_response: Whether a response is expected
            mode: The RPC mode ("sync", "async", "future")
            call_method: Name of the client method that performs the call

        Returns:
            Whatever the selected client method returns
        """
        # The parameters travel inside kwargs under a reserved key; the
        # client entry points in this file pop it back out before building
        # the request.
        self.kwargs["__rpc_params"] = RPCParams(timeout=timeout,
                                                need_response=need_response,
                                                mode=mode)
        dispatch = getattr(self.client, call_method)
        return dispatch(self.method_name, *self.args, **self.kwargs)

    def remote(self,
               timeout: Optional[float] = None,
               need_response: bool = True) -> Any:
        """Run the call through the client's synchronous entry point."""
        return self._prepare_and_call(timeout=timeout,
                                      need_response=need_response,
                                      mode="sync",
                                      call_method="_call_sync")

    def remote_async(self,
                     timeout: Optional[float] = None,
                     need_response: bool = True):
        """Return the coroutine produced by the client's async entry point."""
        return self._prepare_and_call(timeout=timeout,
                                      need_response=need_response,
                                      mode="async",
                                      call_method="_call_async")

    def remote_future(self,
                      timeout: Optional[float] = None,
                      need_response: bool = True) -> concurrent.futures.Future:
        """Run the call in the background and return a Future for its result."""
        return self._prepare_and_call(timeout=timeout,
                                      need_response=need_response,
                                      mode="future",
                                      call_method="_call_future")

    def remote_streaming(self,
                         timeout: Optional[float] = None) -> AsyncIterator[Any]:
        """Return an async iterator over the results of a streaming call."""
        # A stream is meaningless without responses, so need_response is
        # always True here.
        return self._prepare_and_call(timeout=timeout,
                                      need_response=True,
                                      mode="async",
                                      call_method="_call_streaming")
|
||
|
||
|
||
class RPCClient:
    """
    An RPC Client that connects to the RPCServer.

    The client owns a dedicated background event loop (running in a daemon
    thread) and a reader task on that loop which dispatches incoming
    responses either to per-request futures (regular calls) or to
    per-request queues (streaming calls).

    Remote methods are reached through attribute access:
        client.method(args).remote()
        await client.method(args).remote_async()
        client.method(args).remote_future()
        async for x in client.method(args).remote_streaming()
    """

    def __init__(self,
                 address: str,
                 hmac_key=None,
                 timeout: Optional[float] = None,
                 num_workers: int = 4):
        '''
        Args:
            address: The ZMQ address to connect to.
            hmac_key: The HMAC key for encryption.
            timeout: The timeout (seconds) for RPC calls.
            num_workers: The number of workers for the RPC client.
        '''
        self._address = address
        self._timeout = timeout

        # Check if PAIR mode is enabled via environment variable
        use_pair_mode = os.environ.get('TLLM_LLMAPI_ZMQ_PAIR', '0') != '0'
        socket_type = zmq.PAIR if use_pair_mode else zmq.DEALER

        if use_pair_mode:
            logger_debug(
                "[client] Using zmq.PAIR socket type for RPC communication")

        self._client_socket = ZeroMqQueue(address=(address, hmac_key),
                                          is_server=False,
                                          is_async=True,
                                          use_hmac_encryption=hmac_key
                                          is not None,
                                          socket_type=socket_type,
                                          name="rpc_client")
        # map request_id to (future, loop the future belongs to)
        self._pending_futures = {}
        # map request_id to the queue for streaming responses
        self._streaming_queues: Dict[str, 'AsyncQueue'] = {}
        self._streaming_queues_lock = threading.RLock(
        )  # Protect cross-thread access
        self._reader_task = None
        self._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=num_workers, thread_name_prefix="rpc_client_worker")

        self._server_stopped = False
        self._closed = False
        self._loop = None
        self._loop_thread = None
        self._reader_asyncio_task = None  # Track the asyncio task for proper cancellation
        self._loop_lock = threading.Lock(
        )  # Lock to protect event loop initialization

        # Eagerly create the background event loop so that all subsequent
        # RPC calls (sync or streaming) can assume it exists. This removes a
        # race between the first streaming call (which previously created the
        # loop lazily) and immediate fire-and-forget calls like `submit()`.
        self._ensure_event_loop()

        # Force ZeroMqQueue client connection during initialization
        # This ensures the socket is connected before any RPC operations
        self._client_socket.setup_lazily()

        # Start the response reader eagerly to avoid race conditions with streaming
        # This ensures the reader is processing responses before any RPC calls
        self._start_response_reader_eagerly()

        logger_debug(
            f"[client] RPC Client initialized. Connected to {self._address}")

    def shutdown_server(self):
        """Ask the server to shut down; a no-op once the server is marked stopped."""
        if self._server_stopped:
            return

        self._rpc_shutdown().remote()

        self._server_stopped = True

    def close(self):
        """Gracefully close the client, cleaning up background tasks.

        Idempotent: the second and later calls return immediately.
        """
        if self._closed:
            return
        self._closed = True

        logger_debug("[client] RPC Client closing")

        # Notify any active streaming consumers so they don't hang waiting for
        # further data. This must be done *before* shutting down the event
        # loop/executor because they may depend on the loop to complete.
        self._broadcast_streaming_error(RPCCancelled("RPC client closed"))

        # 1. Cancel the reader task
        if self._reader_task and not self._reader_task.done():
            if self._loop and self._loop.is_running(
            ) and self._reader_asyncio_task:
                try:

                    async def cancel_reader_task():
                        # Runs on the client loop, the loop that owns the
                        # reader task. Bind the task locally because the
                        # reader clears the attribute when it exits.
                        task = self._reader_asyncio_task
                        if task and not task.done():
                            task.cancel()
                            try:
                                await task
                            except asyncio.CancelledError:
                                pass

                    cancel_future = asyncio.run_coroutine_threadsafe(
                        cancel_reader_task(), self._loop)
                    cancel_future.result(timeout=2.0)
                    logger_debug("[client] Reader task cancelled successfully")
                except concurrent.futures.TimeoutError:
                    logger.warning("Reader task did not exit gracefully")
                except Exception as e:
                    # Best-effort cleanup: record and carry on closing.
                    logger_debug(f"[client] Reader task cleanup: {e}")

        # 2. Stop the event loop
        if self._loop and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)

        # 3. Join the event loop thread
        if self._loop_thread:
            self._loop_thread.join(timeout=2.0)
            if self._loop_thread.is_alive():
                logger.warning("Event loop thread did not exit gracefully")

        # 4. Shutdown the executor
        if self._executor:
            self._executor.shutdown(wait=True)

        # 5. Close the socket
        if self._client_socket:
            self._client_socket.close()

        logger_debug("[client] RPC Client closed")

    def _handle_streaming_response(self, response: 'RPCResponse'):
        """Handle a streaming response by putting it in the appropriate queue.

        Responses whose request_id has no registered queue are ignored.

        Args:
            response: The streaming response to handle
        """
        assert response.stream_status in [
            'start', 'data', 'end', 'error'
        ], f"Invalid stream status: {response.stream_status}"

        with self._streaming_queues_lock:
            queue = self._streaming_queues.get(response.request_id)
            if queue:
                logger_debug(
                    f"[client] [{datetime.now().isoformat()}] Found streaming queue for response: request_id={response.request_id}, "
                    f"status={response.stream_status}")
        if queue:
            # put to the sync queue, as the current event loop is
            # different from the one in call_async or call_streaming
            assert isinstance(queue, AsyncQueue)
            if enable_llmapi_debug() or logger.level == 'debug':
                logger_debug(
                    f"[client] RPC Client putting response to AsyncQueue: status={response.stream_status}, request_id={response.request_id}"
                )
            queue.sync_q.put(response)
            # Clean up if stream ended
            if response.stream_status in ['end', 'error']:
                with self._streaming_queues_lock:
                    self._streaming_queues.pop(response.request_id, None)

    def _handle_regular_response(self, response: 'RPCResponse'):
        """Handle a regular (non-streaming) response by setting the future result.

        Args:
            response: The response to handle
        """
        # Remove the entry up front so it is released even if scheduling the
        # completion callback below raises (e.g. the caller's loop is closed).
        future_info = self._pending_futures.pop(response.request_id, None)
        debug_enabled = enable_llmapi_debug() or logger.level == 'debug'

        if not future_info:
            if debug_enabled:
                logger_debug(
                    f"[client] No future found for request_id: {response.request_id}"
                )
            return

        future, target_loop = future_info
        if future.done():
            return

        def safe_set_result():
            """Safely set result on future, handling race conditions."""
            try:
                if not future.done():
                    if response.error is None:
                        future.set_result(response.result)
                    else:
                        future.set_exception(response.error)
            except asyncio.InvalidStateError:
                # Future was cancelled or completed between the check and set
                # This is expected in high-load scenarios, just log and continue
                if enable_llmapi_debug() or logger.level == 'debug':
                    logger_debug(
                        f"[client] Future already done for request_id: {response.request_id}, skipping"
                    )

        if debug_enabled:
            if response.error is None:
                logger_debug(
                    f"[client] Setting result for request_id: {response.request_id}"
                )
            else:
                logger_debug(
                    f"[client] Setting exception for request_id: {response.request_id}, error: {response.error}"
                )

        # asyncio futures must be completed from the loop they belong to.
        target_loop.call_soon_threadsafe(safe_set_result)

    async def _handle_reader_exception(self, exception: Exception):
        """Propagate an exception to all pending futures and streaming queues.

        Args:
            exception: The exception to propagate
        """
        logger.error(f"Exception in RPC response reader: {exception}")

        # Iterate over a snapshot: callers on other loops/threads add and
        # remove entries concurrently, and mutating a dict while iterating it
        # raises RuntimeError.
        for (future, target_loop) in list(self._pending_futures.values()):
            if not future.done():

                def safe_set_exception(f=future, exc=exception):
                    """Safely set exception on future, handling race conditions."""
                    try:
                        if not f.done():
                            f.set_exception(exc)
                    except asyncio.InvalidStateError:
                        # Future was cancelled or completed, this is fine
                        pass

                target_loop.call_soon_threadsafe(safe_set_exception)

        # Propagate to streaming queues via common helper
        self._broadcast_streaming_error(exception)

    async def _wait_for_response(self) -> 'RPCResponse':
        """Wait for a response from the socket.

        Returns:
            RPCResponse from the server

        Raises:
            asyncio.CancelledError: when a receive times out and the client
                has been closed or the reader task was cancelled.
        """
        # Use timeout-based recv to handle cancellation gracefully
        # This prevents the CancelledError from being logged as an exception
        while True:
            try:
                # Short timeout allows periodic checks for cancellation
                return await self._client_socket.get_async_noblock(timeout=2)
            except asyncio.TimeoutError:
                # Check if we should exit due to cancellation
                if self._closed or (self._reader_asyncio_task
                                    and self._reader_asyncio_task.cancelled()):
                    raise asyncio.CancelledError("Reader task cancelled")
                # Otherwise continue polling
                continue

    async def _response_reader(self):
        """Task to read responses from the socket and set results on futures."""
        logger_debug("[client] Response reader started")
        self._reader_asyncio_task = asyncio.current_task()

        try:
            # Add initial delay to ensure socket is fully connected
            # This helps prevent race conditions during initialization
            await asyncio.sleep(0.1)
            logger_debug("[client] Response reader ready to process messages")

            with customized_gc_thresholds(10000):
                last_alive_log = time.time()
                while not self._closed:
                    # Periodic alive logging for debugging
                    if time.time() - last_alive_log > 5.0:
                        logger_debug(
                            "[client] Response reader is alive and waiting for responses"
                        )
                        last_alive_log = time.time()

                    with nvtx_range_debug("response_reader",
                                          color="cyan",
                                          category="RPC"):
                        try:
                            response = await self._wait_for_response()

                            nvtx_mark_debug(
                                f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
                                color="black",
                                category="RPC")

                            # Optimize: Check debug flag before expensive string operations
                            # This avoids holding GIL for f-string evaluation when debug is disabled
                            if enable_llmapi_debug() or logger.level == 'debug':
                                logger_debug(
                                    f"[client] [{datetime.now().isoformat()}] Received response: {response}"
                                )
                                logger_debug(
                                    f"[client] [{datetime.now().isoformat()}] RPC Client received response: request_id={response.request_id}, "
                                    f"is_streaming={response.is_streaming}, "
                                    f"pending_futures={len(self._pending_futures)}"
                                )

                            with nvtx_range_debug("handle_response",
                                                  color="purple",
                                                  category="RPC"):
                                if response.is_streaming:
                                    self._handle_streaming_response(response)
                                else:
                                    self._handle_regular_response(response)

                        except asyncio.CancelledError:
                            # Re-raise cancellation to exit cleanly
                            raise
                        except Exception as e:
                            # Log the error but continue reading unless it's a critical error
                            logger.error(f"Error processing response: {e}",
                                         exc_info=True)

                            # For critical errors, propagate and exit
                            if isinstance(e, (ConnectionError, zmq.ZMQError)):
                                await self._handle_reader_exception(e)
                                break

                            # For other errors, try to continue
                            await asyncio.sleep(
                                0.1
                            )  # Brief pause to avoid tight loop on repeated errors

        except asyncio.CancelledError:
            logger_debug("[client] Response reader cancelled")
        finally:
            logger_debug("[client] Response reader exiting gracefully")
            self._reader_task = None
            self._reader_asyncio_task = None

    def _start_response_reader_eagerly(self):
        """Start the response reader immediately during initialization.

        This ensures the reader is ready before any RPC calls are made,
        preventing race conditions with streaming responses.

        Raises:
            RuntimeError: if the background event loop is not running.
        """
        if self._reader_task is not None and not self._reader_task.done():
            return  # Already running

        try:
            if self._loop and self._loop.is_running():
                future = asyncio.run_coroutine_threadsafe(
                    self._response_reader(), self._loop)
                self._reader_task = future
                # Wait a bit to ensure the reader is actually processing
                time.sleep(0.2)
                logger_debug(
                    "[client] Response reader started eagerly during initialization"
                )
            else:
                raise RuntimeError(
                    "Event loop not running when trying to start response reader"
                )
        except Exception as e:
            logger.error(f"Failed to start response reader eagerly: {e}")
            self._reader_task = None
            raise

    async def _call_async(self, method_name, *args, **kwargs):
        """Async version of RPC call.

        Args:
            method_name: Method name to call
            *args: Positional arguments
            **kwargs: Keyword arguments
            __rpc_params: RPCParams object containing RPC parameters.

        Returns:
            The result of the remote method call, or None when no response
            was requested.

        Raises:
            RPCCancelled: if the server is already marked as stopped, or the
                call completes with RPCCancelled.
            RPCTimeout: if no response arrives within the effective timeout.
        """
        if enable_llmapi_debug() or logger.level == 'debug':
            logger_debug(
                f"[client] [{datetime.now().isoformat()}] RPC client calling method: {method_name}"
            )
        nvtx_mark_debug(f"RPC.async.{method_name}",
                        color="yellow",
                        category="RPC")
        if self._server_stopped:
            raise RPCCancelled("Server is shutting down, request cancelled")

        rpc_params = kwargs.pop("__rpc_params", RPCParams())
        need_response = rpc_params.need_response
        timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout

        request_id = uuid.uuid4().hex
        request = RPCRequest(request_id,
                             method_name=method_name,
                             args=args,
                             kwargs=kwargs,
                             need_response=need_response,
                             timeout=timeout)

        # Register the future *before* sending. Sending yields control, and
        # the reader (possibly running on another thread's loop) could
        # otherwise receive the reply before the future exists, drop it as
        # "No future found", and leave this call waiting forever.
        future = None
        if need_response:
            loop = asyncio.get_running_loop()
            future = loop.create_future()
            self._pending_futures[request_id] = (future, loop)

        try:
            await self._client_socket.put_async(request)
            logger_debug(f"[client] RPC Client sent request: {request}")

            if future is None:
                return None

            try:
                # The same timeout is forwarded to the server in the request;
                # locally the wait is bounded by exactly that value (no extra
                # grace period is added).
                if timeout is None:
                    return await future
                return await asyncio.wait_for(future, timeout)
            except RPCCancelled:
                self._server_stopped = True
                raise
            except asyncio.TimeoutError:
                raise RPCTimeout(
                    f"Request '{method_name}' timed out after {timeout}s")
        finally:
            # Covers success, failure, cancellation and a failed send.
            self._pending_futures.pop(request_id, None)

    def _ensure_event_loop(self):
        """Create and start the background event loop.

        This is called once during initialization to create the dedicated
        event loop for all socket I/O operations. Returns only after the loop
        is actually running.

        Raises:
            RuntimeError: if the loop does not start within 10 seconds.
        """
        if self._loop is not None:
            return  # Already created

        self._loop = asyncio.new_event_loop()
        loop_started = threading.Event()

        def run_loop():
            asyncio.set_event_loop(self._loop)
            # Executed as the first callback once run_forever() is spinning,
            # so the event is a reliable "loop is running" signal (unlike a
            # fixed sleep, which can be too short on a loaded host).
            self._loop.call_soon(loop_started.set)
            self._loop.run_forever()

        self._loop_thread = threading.Thread(target=run_loop,
                                             daemon=True,
                                             name="rpc_client_loop")
        self._loop_thread.start()

        if not loop_started.wait(timeout=10.0):
            raise RuntimeError(
                "RPC client event loop did not start within 10 seconds")

    def _call_sync(self, method_name, *args, **kwargs):
        """Synchronous version of RPC call.

        Runs _call_async on the client loop and blocks the calling thread
        until it completes.
        """
        if enable_llmapi_debug() or logger.level == 'debug':
            logger_debug(f"[client] RPC Client calling method: {method_name}")
        nvtx_mark_debug(f"RPC.sync.{method_name}",
                        color="green",
                        category="RPC")
        future = asyncio.run_coroutine_threadsafe(
            self._call_async(method_name, *args, **kwargs), self._loop)
        return future.result()

    def _call_future(self, name: str, *args,
                     **kwargs) -> concurrent.futures.Future:
        """
        Call a remote method and return a Future.

        Args:
            name: Method name to call
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            A Future object that can be used to retrieve the result
        """
        nvtx_mark_debug(f"RPC.future.{name}", color="blue", category="RPC")

        def _async_to_sync():
            # Executed on a worker thread: run the coroutine on the client
            # loop and block this worker until it finishes.
            future = asyncio.run_coroutine_threadsafe(
                self._call_async(name, *args, **kwargs), self._loop)
            return future.result()

        return self._executor.submit(_async_to_sync)

    async def _call_streaming(self, name: str, *args,
                              **kwargs) -> AsyncIterator[Any]:
        """
        Call a remote async generator method and get streaming results.

        Implementation note: The outgoing request is sent on the RPCClient's
        private event-loop to obey the single-loop rule. The returned items
        are yielded in the caller's loop via AsyncQueue, which is thread-safe.

        Raises:
            RPCCancelled: if the server is already marked as stopped.
            RPCTimeout: if waiting for the next item exceeds the timeout.
            RPCStreamingError: on an 'error' response carrying no exception;
                otherwise the exception carried by the response is raised.
        """
        nvtx_mark_debug(f"RPC.streaming.{name}", color="red", category="RPC")

        if self._server_stopped:
            raise RPCCancelled("Server is shutting down, request cancelled")
        rpc_params = kwargs.pop("__rpc_params", RPCParams())
        timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout

        request_id = uuid.uuid4().hex
        # Use AsyncQueue to ensure proper cross-thread communication
        queue = AsyncQueue()
        # Recreate sync_q with the current running loop for proper cross-thread communication
        queue._sync_q = _SyncQueue(queue, asyncio.get_running_loop())

        # Register queue with lock to ensure thread-safe access
        with self._streaming_queues_lock:
            self._streaming_queues[request_id] = queue

        # Everything after registration is covered by the finally below, so a
        # failed or cancelled send no longer leaves a stale queue registered.
        try:
            # Build the RPCRequest object here - it's pickle-able and small -
            # but *do not* touch the ZeroMQ socket from this (caller)
            # event-loop.
            request = RPCRequest(request_id,
                                 method_name=name,
                                 args=args,
                                 kwargs=kwargs,
                                 need_response=True,
                                 timeout=timeout,
                                 is_streaming=True)

            # Send the request on the RPCClient's dedicated loop to guarantee
            # that **all** socket I/O happens from exactly one thread / loop.
            async def _send_streaming_request(req: 'RPCRequest'):
                """Private helper executed in the client loop to put the request."""
                logger_debug(
                    f"[client] [{datetime.now().isoformat()}] Sending streaming request: {req.method_name}, request_id={req.request_id}"
                )
                await self._client_socket.put_async(req)
                logger_debug(
                    f"[client][{datetime.now().isoformat()}] Streaming request sent successfully: {req.method_name}, request_id={req.request_id}"
                )

            send_future = asyncio.run_coroutine_threadsafe(
                _send_streaming_request(request), self._loop)

            # Wait until the request is actually on the wire before entering
            # the user-visible streaming loop. We wrap the
            # concurrent.futures.Future so we can await it in the caller's
            # asyncio context.
            await asyncio.wrap_future(send_future)

            try:
                logger_debug(
                    f"[client] [{datetime.now().isoformat()}] Starting to read streaming responses for request_id={request_id}"
                )
                # Read streaming responses
                while True:
                    if timeout is None:
                        response = await queue.get()
                    else:
                        response = await asyncio.wait_for(queue.get(),
                                                          timeout=timeout)

                    if enable_llmapi_debug() or logger.level == 'debug':
                        logger_debug(
                            f"[client] [{datetime.now().isoformat()}] RPC Client _call_streaming received [{response.stream_status}] response",
                            color="green")

                    if response.stream_status == 'start':
                        # Start of stream
                        continue
                    elif response.stream_status == 'data':
                        yield response.result
                    elif response.stream_status == 'end':
                        # End of stream
                        break
                    elif response.stream_status == 'error':
                        # Error in stream
                        if response.error:
                            raise response.error
                        else:
                            raise RPCStreamingError("Unknown streaming error")

            except asyncio.TimeoutError:
                raise RPCTimeout(
                    f"Streaming request '{name}' timed out after {timeout}s")
        finally:
            # Clean up
            with self._streaming_queues_lock:
                self._streaming_queues.pop(request_id, None)

    def _broadcast_streaming_error(self, exc: Exception):
        """Send an error response to all pending streaming queues so that
        any coroutines blocked in _call_streaming can exit promptly.

        Args:
            exc: The exception instance to propagate downstream.
        """
        # Iterate over a copy because callbacks may mutate the dict
        with self._streaming_queues_lock:
            streaming_items = list(self._streaming_queues.items())
        for request_id, queue in streaming_items:
            if not isinstance(queue, AsyncQueue):
                continue
            try:
                # Use the underlying sync_q to ensure cross-thread delivery
                queue.sync_q.put(
                    RPCResponse(
                        request_id,
                        result=None,
                        error=exc,
                        is_streaming=True,
                        chunk_index=0,
                        stream_status='error',
                    ))
            except Exception as e:
                # Best-effort; log and continue
                if enable_llmapi_debug() or logger.level == 'debug':
                    logger_debug(
                        f"[client] [{datetime.now().isoformat()}] Failed to broadcast streaming error for {request_id}: {e}"
                    )

    def get_server_attr(self, name: str):
        """ Get the attribute of the RPC server.
        This is mainly used for testing. """
        return self._rpc_get_attr(name).remote()

    def __getattr__(self, name):
        """
        Magically handles calls to non-existent methods.
        Returns a callable that when invoked returns a RemoteCall instance.

        This enables the new syntax:
            client.method(args).remote()
            await client.method(args).remote_async()
            client.method(args).remote_future()
            async for x in client.method(args).remote_streaming()

        Raises:
            AttributeError: for dunder names, so protocol probes such as
                ``hasattr(client, "__deepcopy__")`` are answered correctly
                instead of being turned into remote calls.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}")

        logger_debug(
            f"[client] [{datetime.now().isoformat()}] RPC Client getting attribute: {name}"
        )

        def method_caller(*args, **kwargs):
            return RemoteCall(self, name, *args, **kwargs)

        return method_caller

    def __enter__(self):
        """Enter the context manager; returns the client itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the client when leaving the ``with`` block."""
        self.close()
|