[https://nvbugs/5720482][fix] Fix test rpc streaming (#9902)

Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
Yan Chunwei 2025-12-13 17:14:43 +08:00 committed by GitHub
parent 8cbf2d958c
commit 85406f9dda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 232 additions and 71 deletions

View File

@ -59,6 +59,7 @@ class ZeroMqQueue:
self._setup_done = False
self.name = name
self.socket = self.context.socket(socket_type)
self.socket.set_hwm(0)
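# For ROUTER sockets, track the last identity to enable replies. This is only the fallback reply target used when the caller does not pass an explicit routing_id; with more than one client, pass routing_id so the reply is not sent to whichever client spoke last.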
self._last_identity = None
@ -154,14 +155,14 @@ class ZeroMqQueue:
else:
return False
def put(self, obj: Any):
def put(self, obj: Any, routing_id: Optional[bytes] = None):
self.setup_lazily()
self._check_thread_safety()
with nvtx_range_debug("send", color="blue", category="IPC"):
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
# Need manual serialization for encryption or ROUTER multipart
data = self._prepare_data(obj)
self._send_data(data)
self._send_data(data, routing_id=routing_id)
else:
# Standard socket without encryption - use pyobj directly
self.socket.send_pyobj(obj)
@ -197,14 +198,14 @@ class ZeroMqQueue:
else:
logger.error(f"Failed to send object: {obj}")
async def put_async(self, obj: Any):
async def put_async(self, obj: Any, routing_id: Optional[bytes] = None):
self.setup_lazily()
self._check_thread_safety()
try:
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
# Need manual serialization for encryption or ROUTER multipart
data = self._prepare_data(obj)
await self._send_data_async(data)
await self._send_data_async(data, routing_id=routing_id)
else:
# Standard socket without encryption
await self.socket.send_pyobj(obj)
@ -243,7 +244,9 @@ class ZeroMqQueue:
self._check_thread_safety()
return await self._recv_data_async()
async def get_async_noblock(self, timeout: float = 0.5) -> Any:
async def get_async_noblock(self,
timeout: float = 0.5,
return_identity: bool = False) -> Any:
"""Get data with timeout using polling to avoid message drops.
This method uses ZMQ's NOBLOCK flag with polling instead of asyncio.wait_for
@ -251,9 +254,10 @@ class ZeroMqQueue:
Args:
timeout: Timeout in seconds
return_identity: Whether to return the identity of the sender (for ROUTER sockets)
Returns:
The received object
The received object, or (object, identity) if return_identity is True
Raises:
asyncio.TimeoutError: If timeout is reached without receiving data
@ -271,13 +275,22 @@ class ZeroMqQueue:
identity, data = await self.socket.recv_multipart(
flags=zmq.NOBLOCK)
self._last_identity = identity
return self._parse_data(data)
obj = self._parse_data(data)
if return_identity:
return obj, identity
else:
return obj
else:
if self.use_hmac_encryption:
data = await self.socket.recv(flags=zmq.NOBLOCK)
return self._parse_data(data)
obj = self._parse_data(data)
else:
return await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
obj = await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
if return_identity:
return obj, None
else:
return obj
except zmq.Again:
# No message available yet
if asyncio.get_event_loop().time() >= deadline:
@ -329,30 +342,39 @@ class ZeroMqQueue:
else:
return pickle.loads(data) # nosec B301
def _send_data(self, data: bytes, flags: int = 0):
def _send_data(self,
data: bytes,
flags: int = 0,
routing_id: Optional[bytes] = None):
"""Send data using appropriate API based on socket type."""
if self.socket_type == zmq.ROUTER:
if self._last_identity is None:
identity = routing_id if routing_id is not None else self._last_identity
if identity is None:
raise ValueError("ROUTER socket requires identity")
self.socket.send_multipart([self._last_identity, data], flags=flags)
self.socket.send_multipart([identity, data], flags=flags)
else:
self.socket.send(data, flags=flags)
async def _send_data_async(self, data: bytes):
async def _send_data_async(self,
data: bytes,
routing_id: Optional[bytes] = None):
"""Async version of _send_data."""
if self.socket_type == zmq.ROUTER:
if self._last_identity is None:
identity = routing_id if routing_id is not None else self._last_identity
if identity is None:
raise ValueError("ROUTER socket requires identity")
await self.socket.send_multipart([self._last_identity, data])
await self.socket.send_multipart([identity, data])
else:
await self.socket.send(data)
def _recv_data(self) -> Any:
def _recv_data(self, return_identity: bool = False) -> Any:
"""Receive data using appropriate API based on socket type."""
if self.socket_type == zmq.ROUTER:
identity, data = self.socket.recv_multipart()
self._last_identity = identity # Store for replies
obj = self._parse_data(data)
if return_identity:
return obj, identity
return obj
else:
if self.use_hmac_encryption:
@ -360,20 +382,30 @@ class ZeroMqQueue:
obj = self._parse_data(data)
else:
obj = self.socket.recv_pyobj()
if return_identity:
return obj, None
return obj
async def _recv_data_async(self) -> Any:
async def _recv_data_async(self, return_identity: bool = False) -> Any:
"""Async version of _recv_data."""
if self.socket_type == zmq.ROUTER:
identity, data = await self.socket.recv_multipart()
self._last_identity = identity # Store for replies
return self._parse_data(data)
obj = self._parse_data(data)
if return_identity:
return obj, identity
return obj
else:
if self.use_hmac_encryption:
data = await self.socket.recv()
return self._parse_data(data)
obj = self._parse_data(data)
else:
return await self.socket.recv_pyobj()
obj = await self.socket.recv_pyobj()
if return_identity:
return obj, None
return obj
def notify_with_retry(self, message, max_retries=5, timeout=1):
"""

View File

@ -75,6 +75,7 @@ class RPCRequest:
is_streaming: bool = False
creation_timestamp: Optional[
float] = None # Unix timestamp when request was created
routing_id: Optional[bytes] = None
def __post_init__(self):
"""Initialize creation_timestamp if not provided."""

View File

@ -228,8 +228,10 @@ class RPCServer:
while asyncio.get_event_loop().time() < end_time:
try:
req: RPCRequest = await asyncio.wait_for(
self._client_socket.get_async_noblock(), timeout=2)
req, routing_id = await asyncio.wait_for(
self._client_socket.get_async_noblock(return_identity=True),
timeout=2)
req.routing_id = routing_id
drained_count += 1
logger_debug(f"[server] Draining request after shutdown: {req}")
@ -299,13 +301,16 @@ class RPCServer:
error=error,
is_streaming=
True, # Important: mark as streaming so it gets routed correctly
stream_status='error'))
stream_status='error'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error response for request {req.request_id}",
color="green")
else:
await self._client_socket.put_async(
RPCResponse(req.request_id, result=None, error=error))
await self._client_socket.put_async(RPCResponse(req.request_id,
result=None,
error=error),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error response for request {req.request_id}",
color="green")
@ -335,8 +340,10 @@ class RPCServer:
try:
#logger_debug(f"[server] Worker waiting for request", color="green")
# Read request directly from socket with timeout
req: RPCRequest = await asyncio.wait_for(
self._client_socket.get_async_noblock(), timeout=2)
req, routing_id = await asyncio.wait_for(
self._client_socket.get_async_noblock(return_identity=True),
timeout=2)
req.routing_id = routing_id
logger_debug(f"[server] Worker got request: {req}",
color="green")
except asyncio.TimeoutError:
@ -492,15 +499,15 @@ class RPCServer:
func = self._functions[req.method_name]
if not inspect.isasyncgenfunction(func):
await self._client_socket.put_async(
RPCResponse(
req.request_id,
result=None,
error=RPCStreamingError(
f"Method '{req.method_name}' is not an async generator.",
traceback=traceback.format_exc()),
is_streaming=True,
stream_status='error'))
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=RPCStreamingError(
f"Method '{req.method_name}' is not an async generator.",
traceback=traceback.format_exc()),
is_streaming=True,
stream_status='error'),
routing_id=req.routing_id)
return
chunk_index = 0
@ -512,13 +519,14 @@ class RPCServer:
logger_debug(
f"[server] RPC Server running streaming task {req.method_name}")
# Send start signal
await self._client_socket.put_async(
RPCResponse(req.request_id,
result=None,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='start'))
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='start'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent start signal for request {req.request_id}",
color="green")
@ -584,39 +592,41 @@ class RPCServer:
chunk_index += 1
# Send end signal
await self._client_socket.put_async(
RPCResponse(req.request_id,
result=None,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='end'))
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=None,
is_streaming=True,
chunk_index=chunk_index,
stream_status='end'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent end signal for request {req.request_id}",
color="green")
except RPCCancelled as e:
# Server is shutting down, send cancelled error
await self._client_socket.put_async(
RPCResponse(req.request_id,
result=None,
error=e,
is_streaming=True,
chunk_index=chunk_index,
stream_status='error'))
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=e,
is_streaming=True,
chunk_index=chunk_index,
stream_status='error'),
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error signal for request {req.request_id}",
color="green")
except asyncio.TimeoutError:
await self._client_socket.put_async(
RPCResponse(
req.request_id,
result=None,
error=RPCTimeout(
f"Streaming method '{req.method_name}' timed out",
traceback=traceback.format_exc()),
is_streaming=True,
chunk_index=chunk_index,
stream_status='error'))
await self._client_socket.put_async(RPCResponse(
req.request_id,
result=None,
error=RPCTimeout(
f"Streaming method '{req.method_name}' timed out",
traceback=traceback.format_exc()),
is_streaming=True,
chunk_index=chunk_index,
stream_status='error'),
routing_id=req.routing_id)
except Exception as e:
response = RPCResponse(
@ -633,7 +643,8 @@ class RPCServer:
response: RPCResponse) -> bool:
"""Safely sends a response, handling pickle errors."""
try:
await self._client_socket.put_async(response)
await self._client_socket.put_async(response,
routing_id=req.routing_id)
logger_debug(f"[server] Sent response for request {req.request_id}",
color="green")
return True
@ -661,7 +672,8 @@ class RPCServer:
traceback=traceback.format_exc()))
try:
await self._client_socket.put_async(error_response)
await self._client_socket.put_async(error_response,
routing_id=req.routing_id)
logger_debug(
f"[server] Sent error response for request {req.request_id}",
color="green")

View File

@ -422,6 +422,5 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUT
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_2gpus[cutlass-two_model-overlap_scheduler] SKIP (https://nvbugs/5702826)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler] SKIP (https://nvbugs/5702826)
disaggregated/test_auto_scaling.py::test_worker_restart[etcd-round_robin] SKIP (https://nvbugs/5726118)
unittest/executor/test_rpc.py::TestRpcCorrectness::test_incremental_task_streaming SKIP (https://nvbugs/5720482)
unittest/llmapi/test_llm_pytorch.py::test_llm_reward_model SKIP (https://nvbugs/5670458)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[enable_configurable_moe-moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5727475)

View File

@ -1,4 +1,5 @@
import asyncio
import concurrent.futures
import threading
import time
@ -200,7 +201,9 @@ class TestRpcCorrectness:
) == no + 1, f"result {future.result()} != {no + 1}"
def test_incremental_task_streaming(self):
with RpcServerWrapper(TestRpcCorrectness.App()) as server:
with RpcServerWrapper(TestRpcCorrectness.App(),
async_run_task=True) as server:
with RPCClient(server.addr) as client:
async def test_streaming_task():
@ -218,6 +221,30 @@ class TestRpcCorrectness:
asyncio.run(test_streaming_task())
def test_multi_client_to_single_server(self):
    """Check that several RPC clients can share one RPC server.

    Ten clients are connected to the same server address. Each one then
    sends a single ``echo`` request (one client at a time) and must get
    its own message back.
    """
    import contextlib

    class App:

        def echo(self, msg: str) -> str:
            return msg

    num_clients = 10
    with RpcServerWrapper(App()) as server:
        # Register close() as soon as each client exists, so clients that
        # were already opened are still closed if a later constructor or
        # an assertion below raises.
        with contextlib.ExitStack() as stack:
            clients = []
            for _ in range(num_clients):
                client = RPCClient(server.addr)
                stack.callback(client.close)
                clients.append(client)

            # Perform one request from every client.
            for i, client in enumerate(clients):
                msg = f"hello from client {i}"
                ret = client.echo(msg).remote()
                assert ret == msg, f"Client {i} failed: expected '{msg}', got '{ret}'"
class TestRpcError:
@ -1006,3 +1033,93 @@ class TestRpcRobustness:
f"Iteration {i}/{num_calls} completed successfully")
print(f"All {num_calls} iterations completed successfully")
@pytest.mark.parametrize("concurrency", [10, 50, 100])
def test_many_client_to_single_server(self, concurrency):
    """Pressure test: many short-lived clients against a single server.

    A thread pool with ``concurrency`` workers runs ``total_clients``
    sessions. Each session opens its own client and performs
    ``requests_per_client`` echo round trips, checking every reply.

    Args:
        concurrency: Number of worker threads, i.e. the maximum number of
            client sessions in flight at the same time.
    """

    class App:

        def echo(self, msg: str) -> str:
            return msg

    total_clients = max(200, concurrency * 2)
    requests_per_client = 100

    with RpcServerWrapper(App(), async_run_task=True) as server:

        def run_client_session(client_id):
            # Any failure simply propagates; it is stored on the future
            # and reported exactly once by the main thread below.
            with RPCClient(server.addr) as client:
                for i in range(requests_per_client):
                    msg = f"c{client_id}-req{i}"
                    ret = client.echo(msg).remote()
                    assert ret == msg, f"expected '{msg}', got '{ret}'"

        with concurrent.futures.ThreadPoolExecutor(
                max_workers=concurrency) as executor:
            futures = [
                executor.submit(run_client_session, i)
                for i in range(total_clients)
            ]
            concurrent.futures.wait(futures)

        # Futures were submitted in client_id order, so the list index
        # identifies the failing client.
        errors = []
        for client_id, future in enumerate(futures):
            exc = future.exception()
            if exc is not None:
                errors.append(f"Client {client_id} error: {exc}")

        assert not errors, f"Encountered errors: {errors[:5]}..."
@pytest.mark.parametrize("concurrency", [10, 50, 100])
def test_many_client_to_single_server_threaded(self, concurrency):
    """Pressure test where each client is created and used in a worker thread.

    Every session constructs its ``RPCClient`` inside the pool thread that
    also issues its requests, so no client object crosses a thread
    boundary.

    Args:
        concurrency: Number of worker threads in the pool.
    """

    class App:

        def echo(self, msg: str) -> str:
            return msg

    # More sessions than workers, so sessions queue up behind the pool.
    total_clients = max(200, concurrency * 2)
    requests_per_client = 100

    with RpcServerWrapper(App(), async_run_task=True) as server:

        def run_client_session(client_id):
            # Client creation and usage happen strictly within this
            # thread. Failures propagate to the future and are reported
            # once by the main thread below.
            with RPCClient(server.addr) as client:
                for i in range(requests_per_client):
                    msg = f"c{client_id}-req{i}"
                    ret = client.echo(msg).remote()
                    assert ret == msg, f"expected '{msg}', got '{ret}'"

        # Use ThreadPoolExecutor to simulate concurrent threads.
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=concurrency) as executor:
            futures = [
                executor.submit(run_client_session, i)
                for i in range(total_clients)
            ]
            concurrent.futures.wait(futures)

        # Futures were submitted in client_id order, so the list index
        # identifies the failing client.
        errors = []
        for client_id, future in enumerate(futures):
            exc = future.exception()
            if exc is not None:
                errors.append(f"Client {client_id} error: {exc}")

        assert not errors, f"Encountered errors: {errors[:5]}..."