# TensorRT-LLMs/tensorrt_llm/executor/rpc/rpc_common.py
# Yan Chunwei 85406f9dda
# [https://nvbugs/5720482][fix] Fix test rpc streaming (#9902)
# Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
# 2025-12-13 01:14:43 -08:00
#
# 95 lines
# 2.6 KiB
# Python

import os
import tempfile
import time
import uuid
from dataclasses import KW_ONLY, dataclass
from typing import Any, Literal, NamedTuple, Optional
def get_unique_ipc_addr() -> str:
    """Build a fresh ``ipc://`` address located in the system temp directory.

    The file name is ``rpc_test_`` followed by a random ``uuid.uuid4()``
    value, so repeated calls yield different addresses. Only the address
    string is produced here; nothing is created on disk by this function.

    Returns:
        The address in the form ``ipc://<tempdir>/rpc_test_<uuid4>``.
    """
    socket_path = os.path.join(tempfile.gettempdir(),
                               f"rpc_test_{uuid.uuid4()}")
    return f"ipc://{socket_path}"
class RPCParams(NamedTuple):
    """Parameters for RPC calls.

    Being a ``NamedTuple``, instances are immutable and compare equal to
    the plain tuple ``(timeout, need_response, mode)``.
    """
    # Seconds to wait for the response (defaults to None).
    timeout: Optional[float] = None
    # Whether the client needs the response; if False, the call returns
    # immediately.
    need_response: bool = True
    # Mode for RPC calls. Restricted to the three documented values so type
    # checkers can flag typos; nothing is enforced at runtime.
    mode: Literal["sync", "async", "future"] = "sync"
# --- Custom Exceptions ---
class RPCError(Exception):
    """RPC-related error raised on the client side.

    Args:
        message: The error message; passed on to ``Exception`` so it is what
            ``str(err)`` shows.
        cause: The original exception behind this error, kept on
            ``self.cause`` (``None`` when not supplied).
        traceback: The traceback of that exception as text, kept on
            ``self.traceback`` (``None`` when not supplied).
    """

    def __init__(self,
                 message: str,
                 cause: Optional[Exception] = None,
                 traceback: Optional[str] = None):
        super().__init__(message)
        # Plain attributes only; the built-in ``__cause__`` chain is left
        # untouched.
        self.cause, self.traceback = cause, traceback
class RPCTimeout(RPCError):
    """Signals that the processing of a request timed out."""
class RPCCancelled(RPCError):
    """Signals that a client request was cancelled.

    Occurs while the server is shutting down: every request that is still
    pending gets cancelled and comes back with this error.
    """
class RPCStreamingError(RPCError):
    """Signals an error related to streaming."""
@dataclass
class RPCRequest:
request_id: str
_: KW_ONLY
method_name: str
args: tuple
kwargs: dict
need_response: bool = True
timeout: float = 0.5
is_streaming: bool = False
creation_timestamp: Optional[
float] = None # Unix timestamp when request was created
routing_id: Optional[bytes] = None
def __post_init__(self):
"""Initialize creation_timestamp if not provided."""
if self.creation_timestamp is None:
self.creation_timestamp = time.time()
@dataclass
class RPCResponse:
    """A single RPC response record.

    ``request_id`` is the only positional field; everything declared after
    the ``KW_ONLY`` marker must be passed by keyword. ``result`` has no
    default and is therefore always required.
    """
    request_id: str
    _: KW_ONLY
    result: Any
    error: Optional[RPCError] = None
    # True if more responses are coming.
    is_streaming: bool = False
    # Used for ordering streaming responses.
    chunk_index: int = 0
    stream_status: Literal["start", "data", "end", "error"] = "data"