mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

commit 85406f9dda (parent 8cbf2d958c)

[https://nvbugs/5720482][fix] Fix test rpc streaming (#9902)

Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
@@ -59,6 +59,7 @@ class ZeroMqQueue:
         self._setup_done = False
         self.name = name
         self.socket = self.context.socket(socket_type)
         self.socket.set_hwm(0)
 
+        # For ROUTER sockets, track the last identity to enable replies. For now we assume there is only one client in our case.
         self._last_identity = None
@@ -154,14 +155,14 @@ class ZeroMqQueue:
         else:
             return False
 
-    def put(self, obj: Any):
+    def put(self, obj: Any, routing_id: Optional[bytes] = None):
         self.setup_lazily()
         self._check_thread_safety()
         with nvtx_range_debug("send", color="blue", category="IPC"):
             if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
                 # Need manual serialization for encryption or ROUTER multipart
                 data = self._prepare_data(obj)
-                self._send_data(data)
+                self._send_data(data, routing_id=routing_id)
             else:
                 # Standard socket without encryption - use pyobj directly
                 self.socket.send_pyobj(obj)
@@ -197,14 +198,14 @@ class ZeroMqQueue:
         else:
             logger.error(f"Failed to send object: {obj}")
 
-    async def put_async(self, obj: Any):
+    async def put_async(self, obj: Any, routing_id: Optional[bytes] = None):
         self.setup_lazily()
         self._check_thread_safety()
         try:
             if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
                 # Need manual serialization for encryption or ROUTER multipart
                 data = self._prepare_data(obj)
-                await self._send_data_async(data)
+                await self._send_data_async(data, routing_id=routing_id)
             else:
                 # Standard socket without encryption
                 await self.socket.send_pyobj(obj)
@@ -243,7 +244,9 @@ class ZeroMqQueue:
         self._check_thread_safety()
         return await self._recv_data_async()
 
-    async def get_async_noblock(self, timeout: float = 0.5) -> Any:
+    async def get_async_noblock(self,
+                                timeout: float = 0.5,
+                                return_identity: bool = False) -> Any:
         """Get data with timeout using polling to avoid message drops.
 
         This method uses ZMQ's NOBLOCK flag with polling instead of asyncio.wait_for
@@ -251,9 +254,10 @@ class ZeroMqQueue:
 
         Args:
             timeout: Timeout in seconds
+            return_identity: Whether to return the identity of the sender (for ROUTER sockets)
 
         Returns:
-            The received object
+            The received object, or (object, identity) if return_identity is True
 
         Raises:
             asyncio.TimeoutError: If timeout is reached without receiving data
@@ -271,13 +275,22 @@ class ZeroMqQueue:
                     identity, data = await self.socket.recv_multipart(
                         flags=zmq.NOBLOCK)
                     self._last_identity = identity
-                    return self._parse_data(data)
+                    obj = self._parse_data(data)
+                    if return_identity:
+                        return obj, identity
+                    else:
+                        return obj
                 else:
                     if self.use_hmac_encryption:
                         data = await self.socket.recv(flags=zmq.NOBLOCK)
-                        return self._parse_data(data)
+                        obj = self._parse_data(data)
                     else:
-                        return await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
+                        obj = await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
+
+                    if return_identity:
+                        return obj, None
+                    else:
+                        return obj
             except zmq.Again:
                 # No message available yet
                 if asyncio.get_event_loop().time() >= deadline:
@@ -329,30 +342,39 @@ class ZeroMqQueue:
         else:
             return pickle.loads(data)  # nosec B301
 
-    def _send_data(self, data: bytes, flags: int = 0):
+    def _send_data(self,
+                   data: bytes,
+                   flags: int = 0,
+                   routing_id: Optional[bytes] = None):
         """Send data using appropriate API based on socket type."""
         if self.socket_type == zmq.ROUTER:
-            if self._last_identity is None:
+            identity = routing_id if routing_id is not None else self._last_identity
+            if identity is None:
                 raise ValueError("ROUTER socket requires identity")
-            self.socket.send_multipart([self._last_identity, data], flags=flags)
+            self.socket.send_multipart([identity, data], flags=flags)
         else:
             self.socket.send(data, flags=flags)
 
-    async def _send_data_async(self, data: bytes):
+    async def _send_data_async(self,
+                               data: bytes,
+                               routing_id: Optional[bytes] = None):
         """Async version of _send_data."""
         if self.socket_type == zmq.ROUTER:
-            if self._last_identity is None:
+            identity = routing_id if routing_id is not None else self._last_identity
+            if identity is None:
                 raise ValueError("ROUTER socket requires identity")
-            await self.socket.send_multipart([self._last_identity, data])
+            await self.socket.send_multipart([identity, data])
         else:
            await self.socket.send(data)
 
-    def _recv_data(self) -> Any:
+    def _recv_data(self, return_identity: bool = False) -> Any:
         """Receive data using appropriate API based on socket type."""
         if self.socket_type == zmq.ROUTER:
             identity, data = self.socket.recv_multipart()
             self._last_identity = identity  # Store for replies
             obj = self._parse_data(data)
+            if return_identity:
+                return obj, identity
             return obj
         else:
             if self.use_hmac_encryption:
@@ -360,20 +382,30 @@ class ZeroMqQueue:
                 obj = self._parse_data(data)
             else:
                 obj = self.socket.recv_pyobj()
+
+            if return_identity:
+                return obj, None
             return obj
 
-    async def _recv_data_async(self) -> Any:
+    async def _recv_data_async(self, return_identity: bool = False) -> Any:
         """Async version of _recv_data."""
         if self.socket_type == zmq.ROUTER:
             identity, data = await self.socket.recv_multipart()
             self._last_identity = identity  # Store for replies
-            return self._parse_data(data)
+            obj = self._parse_data(data)
+            if return_identity:
+                return obj, identity
+            return obj
         else:
             if self.use_hmac_encryption:
                 data = await self.socket.recv()
-                return self._parse_data(data)
+                obj = self._parse_data(data)
             else:
-                return await self.socket.recv_pyobj()
+                obj = await self.socket.recv_pyobj()
+
+            if return_identity:
+                return obj, None
+            return obj
 
     def notify_with_retry(self, message, max_retries=5, timeout=1):
         """
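The routing changes above lean on standard ZeroMQ ROUTER semantics: a ROUTER socket prepends the sender's identity frame to every message it receives, and delivers an outgoing multipart message to whichever peer matches its first frame. Replying with the identity captured for a specific request, rather than with a shared _last_identity, is what keeps responses from being misdelivered once several clients are connected. A minimal standalone pyzmq sketch of that mechanism (it does not use ZeroMqQueue; socket names and payloads are illustrative):

import zmq

ctx = zmq.Context.instance()

# The server-side ROUTER sees [identity, payload] for every inbound message.
router = ctx.socket(zmq.ROUTER)
port = router.bind_to_random_port("tcp://127.0.0.1")

# Two clients with distinct identities.
dealers = []
for name in (b"client-a", b"client-b"):
    dealer = ctx.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, name)
    dealer.connect(f"tcp://127.0.0.1:{port}")
    dealers.append(dealer)

for dealer in dealers:
    dealer.send(b"ping")

for _ in dealers:
    identity, payload = router.recv_multipart()
    # Reply using the identity frame captured from *this* request,
    # so the answer reaches the peer that actually sent it.
    router.send_multipart([identity, b"pong for " + identity])

for dealer in dealers:
    print(dealer.recv())  # b'pong for client-a', b'pong for client-b'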
@@ -75,6 +75,7 @@ class RPCRequest:
     is_streaming: bool = False
     creation_timestamp: Optional[
         float] = None  # Unix timestamp when request was created
+    routing_id: Optional[bytes] = None
 
     def __post_init__(self):
         """Initialize creation_timestamp if not provided."""
@@ -228,8 +228,10 @@ class RPCServer:
 
         while asyncio.get_event_loop().time() < end_time:
             try:
-                req: RPCRequest = await asyncio.wait_for(
-                    self._client_socket.get_async_noblock(), timeout=2)
+                req, routing_id = await asyncio.wait_for(
+                    self._client_socket.get_async_noblock(return_identity=True),
+                    timeout=2)
+                req.routing_id = routing_id
                 drained_count += 1
                 logger_debug(f"[server] Draining request after shutdown: {req}")
 
@@ -299,13 +301,16 @@ class RPCServer:
                         error=error,
                         is_streaming=
                         True,  # Important: mark as streaming so it gets routed correctly
-                        stream_status='error'))
+                        stream_status='error'),
+                    routing_id=req.routing_id)
                 logger_debug(
                     f"[server] Sent error response for request {req.request_id}",
                     color="green")
             else:
-                await self._client_socket.put_async(
-                    RPCResponse(req.request_id, result=None, error=error))
+                await self._client_socket.put_async(RPCResponse(req.request_id,
+                                                                result=None,
+                                                                error=error),
+                                                    routing_id=req.routing_id)
                 logger_debug(
                     f"[server] Sent error response for request {req.request_id}",
                     color="green")
@@ -335,8 +340,10 @@ class RPCServer:
             try:
                 #logger_debug(f"[server] Worker waiting for request", color="green")
                 # Read request directly from socket with timeout
-                req: RPCRequest = await asyncio.wait_for(
-                    self._client_socket.get_async_noblock(), timeout=2)
+                req, routing_id = await asyncio.wait_for(
+                    self._client_socket.get_async_noblock(return_identity=True),
+                    timeout=2)
+                req.routing_id = routing_id
                 logger_debug(f"[server] Worker got request: {req}",
                              color="green")
             except asyncio.TimeoutError:
@@ -492,15 +499,15 @@ class RPCServer:
         func = self._functions[req.method_name]
 
         if not inspect.isasyncgenfunction(func):
-            await self._client_socket.put_async(
-                RPCResponse(
-                    req.request_id,
-                    result=None,
-                    error=RPCStreamingError(
-                        f"Method '{req.method_name}' is not an async generator.",
-                        traceback=traceback.format_exc()),
-                    is_streaming=True,
-                    stream_status='error'))
+            await self._client_socket.put_async(RPCResponse(
+                req.request_id,
+                result=None,
+                error=RPCStreamingError(
+                    f"Method '{req.method_name}' is not an async generator.",
+                    traceback=traceback.format_exc()),
+                is_streaming=True,
+                stream_status='error'),
+                                                routing_id=req.routing_id)
             return
 
         chunk_index = 0
@@ -512,13 +519,14 @@ class RPCServer:
             logger_debug(
                 f"[server] RPC Server running streaming task {req.method_name}")
             # Send start signal
-            await self._client_socket.put_async(
-                RPCResponse(req.request_id,
-                            result=None,
-                            error=None,
-                            is_streaming=True,
-                            chunk_index=chunk_index,
-                            stream_status='start'))
+            await self._client_socket.put_async(RPCResponse(
+                req.request_id,
+                result=None,
+                error=None,
+                is_streaming=True,
+                chunk_index=chunk_index,
+                stream_status='start'),
+                                                routing_id=req.routing_id)
             logger_debug(
                 f"[server] Sent start signal for request {req.request_id}",
                 color="green")
@@ -584,39 +592,41 @@ class RPCServer:
             chunk_index += 1
 
             # Send end signal
-            await self._client_socket.put_async(
-                RPCResponse(req.request_id,
-                            result=None,
-                            error=None,
-                            is_streaming=True,
-                            chunk_index=chunk_index,
-                            stream_status='end'))
+            await self._client_socket.put_async(RPCResponse(
+                req.request_id,
+                result=None,
+                error=None,
+                is_streaming=True,
+                chunk_index=chunk_index,
+                stream_status='end'),
+                                                routing_id=req.routing_id)
             logger_debug(
                 f"[server] Sent end signal for request {req.request_id}",
                 color="green")
         except RPCCancelled as e:
             # Server is shutting down, send cancelled error
-            await self._client_socket.put_async(
-                RPCResponse(req.request_id,
-                            result=None,
-                            error=e,
-                            is_streaming=True,
-                            chunk_index=chunk_index,
-                            stream_status='error'))
+            await self._client_socket.put_async(RPCResponse(
+                req.request_id,
+                result=None,
+                error=e,
+                is_streaming=True,
+                chunk_index=chunk_index,
+                stream_status='error'),
+                                                routing_id=req.routing_id)
             logger_debug(
                 f"[server] Sent error signal for request {req.request_id}",
                 color="green")
         except asyncio.TimeoutError:
-            await self._client_socket.put_async(
-                RPCResponse(
-                    req.request_id,
-                    result=None,
-                    error=RPCTimeout(
-                        f"Streaming method '{req.method_name}' timed out",
-                        traceback=traceback.format_exc()),
-                    is_streaming=True,
-                    chunk_index=chunk_index,
-                    stream_status='error'))
+            await self._client_socket.put_async(RPCResponse(
+                req.request_id,
+                result=None,
+                error=RPCTimeout(
+                    f"Streaming method '{req.method_name}' timed out",
+                    traceback=traceback.format_exc()),
+                is_streaming=True,
+                chunk_index=chunk_index,
+                stream_status='error'),
+                routing_id=req.routing_id)
 
         except Exception as e:
             response = RPCResponse(
@@ -633,7 +643,8 @@ class RPCServer:
                            response: RPCResponse) -> bool:
         """Safely sends a response, handling pickle errors."""
         try:
-            await self._client_socket.put_async(response)
+            await self._client_socket.put_async(response,
+                                                routing_id=req.routing_id)
             logger_debug(f"[server] Sent response for request {req.request_id}",
                          color="green")
             return True
@@ -661,7 +672,8 @@ class RPCServer:
                     traceback=traceback.format_exc()))
 
         try:
-            await self._client_socket.put_async(error_response)
+            await self._client_socket.put_async(error_response,
+                                                routing_id=req.routing_id)
             logger_debug(
                 f"[server] Sent error response for request {req.request_id}",
                 color="green")
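The receive/send pairing above is the crux of the fix: the identity is captured when the request arrives, pinned onto the request object, and echoed back on every response for that request, whether it is a plain reply, a streaming start/chunk/end signal, or an error. A condensed, hypothetical sketch of the pattern (dispatch, streaming, and error handling omitted; req.args is an assumed field here, since only method_name, request_id, and routing_id appear in the diff above):

async def serve_once(queue, functions):
    # Capture the requester's identity together with the request ...
    req, routing_id = await queue.get_async_noblock(return_identity=True)
    req.routing_id = routing_id
    result = functions[req.method_name](*req.args)  # req.args: assumed
    # ... and route the reply back to that requester only, instead of to
    # whichever client happened to send the most recent message.
    await queue.put_async(
        RPCResponse(req.request_id, result=result, error=None),
        routing_id=req.routing_id)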
@@ -422,6 +422,5 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUT
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_2gpus[cutlass-two_model-overlap_scheduler] SKIP (https://nvbugs/5702826)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler] SKIP (https://nvbugs/5702826)
 disaggregated/test_auto_scaling.py::test_worker_restart[etcd-round_robin] SKIP (https://nvbugs/5726118)
-unittest/executor/test_rpc.py::TestRpcCorrectness::test_incremental_task_streaming SKIP (https://nvbugs/5720482)
 unittest/llmapi/test_llm_pytorch.py::test_llm_reward_model SKIP (https://nvbugs/5670458)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[enable_configurable_moe-moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5727475)
@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 import threading
 import time
 
@@ -200,7 +201,9 @@ class TestRpcCorrectness:
             ) == no + 1, f"result {future.result()} != {no + 1}"
 
     def test_incremental_task_streaming(self):
-        with RpcServerWrapper(TestRpcCorrectness.App()) as server:
+        with RpcServerWrapper(TestRpcCorrectness.App(),
+                              async_run_task=True) as server:
+
             with RPCClient(server.addr) as client:
 
                 async def test_streaming_task():
@@ -218,6 +221,30 @@ class TestRpcCorrectness:
 
         asyncio.run(test_streaming_task())
 
+    def test_multi_client_to_single_server(self):
+        """Test that multiple RPC clients can concurrently connect to a single RPC server and execute tasks."""
+
+        class App:
+
+            def echo(self, msg: str) -> str:
+                return msg
+
+        with RpcServerWrapper(App()) as server:
+            # Create multiple clients
+            num_clients = 10
+            clients = [RPCClient(server.addr) for _ in range(num_clients)]
+
+            try:
+                # Perform requests from all clients
+                for i, client in enumerate(clients):
+                    msg = f"hello from client {i}"
+                    ret = client.echo(msg).remote()
+                    assert ret == msg, f"Client {i} failed: expected '{msg}', got '{ret}'"
+            finally:
+                # Clean up clients
+                for client in clients:
+                    client.close()
+
 
 class TestRpcError:
 
@@ -1006,3 +1033,93 @@ class TestRpcRobustness:
                 f"Iteration {i}/{num_calls} completed successfully")
 
         print(f"All {num_calls} iterations completed successfully")
+
+    @pytest.mark.parametrize("concurrency", [10, 50, 100])
+    def test_many_client_to_single_server(self, concurrency):
+        """
+        Pressure test where many clients connect to a single server.
+        Controls concurrency via parameter and ensures each client performs multiple operations.
+        """
+
+        class App:
+
+            def echo(self, msg: str) -> str:
+                return msg
+
+        total_clients = max(200, concurrency * 2)
+        requests_per_client = 100
+
+        with RpcServerWrapper(App(), async_run_task=True) as server:
+            errors = []
+
+            def run_client_session(client_id):
+                try:
+                    with RPCClient(server.addr) as client:
+                        for i in range(requests_per_client):
+                            msg = f"c{client_id}-req{i}"
+                            ret = client.echo(msg).remote()
+                            assert ret == msg
+                except Exception as e:
+                    errors.append(f"Client {client_id} error: {e}")
+                    raise
+
+            with concurrent.futures.ThreadPoolExecutor(
+                    max_workers=concurrency) as executor:
+                futures = [
+                    executor.submit(run_client_session, i)
+                    for i in range(total_clients)
+                ]
+                concurrent.futures.wait(futures)
+
+            # Check for exceptions in futures
+            for f in futures:
+                if f.exception():
+                    errors.append(str(f.exception()))
+
+            assert not errors, f"Encountered errors: {errors[:5]}..."
+
+    @pytest.mark.parametrize("concurrency", [10, 50, 100])
+    def test_many_client_to_single_server_threaded(self, concurrency):
+        """
+        Pressure test where clients are created and used in different threads.
+        """
+        import concurrent.futures
+
+        class App:
+
+            def echo(self, msg: str) -> str:
+                return msg
+
+        # Scale total clients to be more than concurrency to force queueing/reuse
+        total_clients = max(200, concurrency * 2)
+        requests_per_client = 100
+
+        with RpcServerWrapper(App(), async_run_task=True) as server:
+            errors = []
+
+            def run_client_session(client_id):
+                try:
+                    # Client creation and usage happens strictly within this thread
+                    with RPCClient(server.addr) as client:
+                        for i in range(requests_per_client):
+                            msg = f"c{client_id}-req{i}"
+                            ret = client.echo(msg).remote()
+                            assert ret == msg
+                except Exception as e:
+                    errors.append(f"Client {client_id} error: {e}")
+                    raise
+
+            # Use ThreadPoolExecutor to simulate concurrent threads
+            with concurrent.futures.ThreadPoolExecutor(
+                    max_workers=concurrency) as executor:
+                futures = [
+                    executor.submit(run_client_session, i)
+                    for i in range(total_clients)
+                ]
+                concurrent.futures.wait(futures)
+
+            for f in futures:
+                if f.exception():
+                    errors.append(str(f.exception()))
+
+            assert not errors, f"Encountered errors: {errors[:5]}..."
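To reproduce the scenarios these new tests cover, a hypothetical invocation through pytest's Python API (the file path follows the skip-list entry above and is assumed to be relative to the repository's tests/unittest directory):

import pytest

# Select the new multi-client and pressure tests by keyword.
pytest.main([
    "unittest/executor/test_rpc.py",
    "-k", "multi_client or many_client",
    "-q",
])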