mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
663 lines
21 KiB
Python
663 lines
21 KiB
Python
import asyncio
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError,
|
|
RPCServer, RPCStreamingError, RPCTimeout)
|
|
|
|
|
|
class RpcServerWrapper(RPCServer):
|
|
|
|
def __init__(self, *args, addr: str, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.addr = addr
|
|
|
|
def __enter__(self):
|
|
self.bind(self.addr)
|
|
self.start()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
self.shutdown()
|
|
|
|
|
|
class TestRpcBasics:
|
|
|
|
def test_rpc_server_basics(self):
|
|
|
|
class App:
|
|
|
|
def hello(self):
|
|
print("hello")
|
|
|
|
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
|
|
pass
|
|
|
|
def test_remote_call_without_arg(self):
|
|
|
|
class App:
|
|
|
|
def hello(self):
|
|
print("hello")
|
|
return "world"
|
|
|
|
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test") as client:
|
|
ret = client.hello().remote() # sync call
|
|
assert ret == "world"
|
|
|
|
def test_remote_call_with_args(self):
|
|
|
|
class App:
|
|
|
|
def hello(self, name: str, location: str):
|
|
print("hello")
|
|
return f"hello {name} from {location}"
|
|
|
|
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test") as client:
|
|
ret = client.hello("app", "Marvel").remote()
|
|
assert ret == "hello app from Marvel"
|
|
|
|
def test_remote_call_with_kwargs(self):
|
|
|
|
class App:
|
|
|
|
def hello(self, name: str, location: str):
|
|
print("hello")
|
|
return f"hello {name} from {location}"
|
|
|
|
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test") as client:
|
|
ret = client.hello(name="app", location="Marvel").remote()
|
|
assert ret == "hello app from Marvel"
|
|
|
|
def test_remote_call_with_args_and_kwargs(self):
|
|
|
|
class App:
|
|
|
|
def hello(self, name: str, location: str):
|
|
print("hello")
|
|
return f"hello {name} from {location}"
|
|
|
|
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test") as client:
|
|
ret = client.hello(name="app", location="Marvel").remote()
|
|
assert ret == "hello app from Marvel"
|
|
|
|
def test_rpc_server_address(self):
|
|
|
|
class App:
|
|
pass
|
|
|
|
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
|
|
assert server.address == "ipc:///tmp/rpc_test"
|
|
|
|
def test_rpc_with_error(self):
|
|
|
|
class App:
|
|
|
|
def hello(self):
|
|
raise ValueError("hello")
|
|
|
|
with RpcServerWrapper(App(),
|
|
addr="ipc:///tmp/rpc_test_error") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test_error") as client:
|
|
with pytest.raises(RPCError):
|
|
client.hello().remote()
|
|
|
|
def test_rpc_without_wait_response(self):
|
|
|
|
class App:
|
|
|
|
def __init__(self):
|
|
self.task_submitted = False
|
|
|
|
def send_task(self) -> None:
|
|
# Just submit the task and return immediately
|
|
# The result is not important
|
|
self.task_submitted = True
|
|
return None
|
|
|
|
def get_task_submitted(self) -> bool:
|
|
return self.task_submitted
|
|
|
|
with RpcServerWrapper(App(),
|
|
addr="ipc:///tmp/rpc_test_no_wait") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test_no_wait") as client:
|
|
client.send_task().remote(need_response=False)
|
|
time.sleep(
|
|
0.1
|
|
) # wait for some time to make sure the task is submitted
|
|
assert client.get_task_submitted().remote()
|
|
|
|
|
|
class TestRpcError:
|
|
|
|
class CustomError(Exception):
|
|
pass
|
|
|
|
def test_task_error(self):
|
|
"""Test that server-side exceptions are properly wrapped in RPCError with details."""
|
|
|
|
class App:
|
|
|
|
def hello(self):
|
|
raise ValueError("Test error message")
|
|
|
|
def divide_by_zero(self):
|
|
return 1 / 0
|
|
|
|
def custom_exception(self):
|
|
raise TestRpcError.CustomError("Custom error occurred")
|
|
|
|
with RPCServer(App()) as server:
|
|
server.bind("ipc:///tmp/rpc_test_error")
|
|
server.start()
|
|
time.sleep(0.1)
|
|
with RPCClient("ipc:///tmp/rpc_test_error") as client:
|
|
# Test ValueError handling
|
|
with pytest.raises(RPCError) as exc_info:
|
|
client.hello().remote()
|
|
|
|
error = exc_info.value
|
|
assert "Test error message" in str(error)
|
|
assert error.cause is not None
|
|
assert isinstance(error.cause, ValueError)
|
|
assert error.traceback is not None
|
|
assert "ValueError: Test error message" in error.traceback
|
|
|
|
# Test ZeroDivisionError handling
|
|
with pytest.raises(RPCError) as exc_info:
|
|
client.divide_by_zero().remote()
|
|
|
|
error = exc_info.value
|
|
assert "division by zero" in str(error)
|
|
assert error.cause is not None
|
|
assert isinstance(error.cause, ZeroDivisionError)
|
|
assert error.traceback is not None
|
|
|
|
# Test custom exception handling
|
|
with pytest.raises(RPCError) as exc_info:
|
|
client.custom_exception().remote()
|
|
|
|
error = exc_info.value
|
|
assert "Custom error occurred" in str(error)
|
|
assert error.cause is not None
|
|
assert error.traceback is not None
|
|
|
|
def test_shutdown_cancelled_error(self):
|
|
"""Test that pending requests are cancelled with RPCCancelled when server shuts down."""
|
|
|
|
class App:
|
|
|
|
def task(self):
|
|
time.sleep(10)
|
|
return True
|
|
|
|
addr = "ipc:///tmp/rpc_test_cancelled"
|
|
|
|
server = RPCServer(
|
|
App(),
|
|
# only one worker to make it easier to pend requests
|
|
num_workers=1)
|
|
server.bind(addr)
|
|
server.start()
|
|
time.sleep(0.1)
|
|
|
|
with RPCClient(addr) as client:
|
|
client.shutdown_server()
|
|
pending_futures = [client.task().remote_future() for _ in range(10)]
|
|
|
|
for future in pending_futures:
|
|
with pytest.raises(RPCCancelled):
|
|
future.result()
|
|
|
|
time.sleep(5)
|
|
|
|
client.close()
|
|
|
|
@pytest.mark.skip(reason="https://nvbugs/5579234")
|
|
def test_timeout_error(self):
|
|
"""Test that requests that exceed timeout are handled with proper error."""
|
|
|
|
class App:
|
|
|
|
def slow_method(self):
|
|
# Sleep longer than the timeout
|
|
time.sleep(2.0)
|
|
return "completed"
|
|
|
|
with RpcServerWrapper(App(),
|
|
addr="ipc:///tmp/rpc_test_timeout") as server:
|
|
time.sleep(0.1)
|
|
|
|
# Create client with short timeout
|
|
with RPCClient("ipc:///tmp/rpc_test_timeout",
|
|
timeout=0.5) as client:
|
|
with pytest.raises(RPCError) as exc_info:
|
|
client.slow_method().remote(timeout=0.5)
|
|
|
|
error = exc_info.value
|
|
# Should be either a timeout error or RPC error indicating timeout
|
|
assert "timed out" in str(error).lower() or "timeout" in str(
|
|
error).lower()
|
|
|
|
def test_method_not_found_error(self):
|
|
"""Test that calling non-existent methods returns proper error."""
|
|
|
|
class App:
|
|
|
|
def existing_method(self):
|
|
return "exists"
|
|
|
|
with RpcServerWrapper(App(),
|
|
addr="ipc:///tmp/rpc_test_not_found") as server:
|
|
time.sleep(0.1)
|
|
|
|
with RPCClient("ipc:///tmp/rpc_test_not_found") as client:
|
|
with pytest.raises(RPCError) as exc_info:
|
|
client.non_existent_method().remote()
|
|
|
|
error = exc_info.value
|
|
assert "not found" in str(error)
|
|
assert error.traceback is not None
|
|
|
|
|
|
def test_rpc_shutdown_server():
|
|
|
|
class App:
|
|
|
|
def hello(self):
|
|
return "world"
|
|
|
|
with RPCServer(App()) as server:
|
|
server.bind("ipc:///tmp/rpc_test_shutdown")
|
|
server.start()
|
|
time.sleep(0.1)
|
|
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
|
|
ret = client.hello().remote()
|
|
assert ret == "world"
|
|
|
|
client.shutdown_server()
|
|
|
|
time.sleep(5) # the server dispatcher thread need some time to quit
|
|
|
|
|
|
def test_rpc_without_response_performance():
|
|
# At any circumstances, the RPC call without response should be faster than the one with response
|
|
class App:
|
|
|
|
def __init__(self):
|
|
self.task_submitted = False
|
|
|
|
def send_task(self) -> None:
|
|
# Just submit the task and return immediately
|
|
# The result is not important
|
|
time.sleep(0.001)
|
|
return None
|
|
|
|
with RPCServer(App(), num_workers=10) as server:
|
|
server.bind("ipc:///tmp/rpc_test_no_wait")
|
|
server.start()
|
|
time.sleep(0.1)
|
|
with RPCClient("ipc:///tmp/rpc_test_no_wait") as client:
|
|
time_start = time.time()
|
|
for i in range(100):
|
|
client.send_task().remote(need_response=False)
|
|
time_end = time.time()
|
|
|
|
no_wait_time = time_end - time_start
|
|
|
|
time_start = time.time()
|
|
for i in range(100):
|
|
client.send_task().remote(need_response=True)
|
|
time_end = time.time()
|
|
wait_time = time_end - time_start
|
|
|
|
assert no_wait_time < wait_time, f"{no_wait_time} > {wait_time}"
|
|
|
|
|
|
@pytest.mark.parametrize("async_run_task", [True, False])
|
|
@pytest.mark.parametrize("use_ipc_addr", [True, False])
|
|
def test_rpc_benchmark(async_run_task: bool, use_ipc_addr: bool):
|
|
|
|
class App:
|
|
|
|
def cal(self, n: int):
|
|
return n * 2
|
|
|
|
with RPCServer(App(), async_run_task=async_run_task) as server:
|
|
address = "ipc:///tmp/rpc_test" if use_ipc_addr else "tcp://127.0.0.1:*"
|
|
|
|
server.bind(address)
|
|
server.start()
|
|
time.sleep(0.1)
|
|
|
|
with RPCClient(server.address) as client:
|
|
|
|
time_start = time.time()
|
|
for i in range(100):
|
|
ret = client.cal(i).remote(timeout=10) # sync call
|
|
assert ret == i * 2, f"{ret} != {i * 2}"
|
|
time_end = time.time()
|
|
print(
|
|
f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second"
|
|
)
|
|
|
|
|
|
class TestRpcTimeout:
|
|
"""Test RPC timeout functionality for both sync and async calls, sharing server/client."""
|
|
|
|
class App:
|
|
|
|
def slow_operation(self, delay: float):
|
|
"""A method that takes a long time to complete."""
|
|
time.sleep(delay)
|
|
return "completed"
|
|
|
|
def setup_method(self, method):
|
|
"""Setup RPC server and client for timeout tests."""
|
|
# Use unique address based on the test parameter to avoid socket conflicts
|
|
test_name = method.__name__
|
|
self.address = f"ipc:///tmp/rpc_test_timeout_{test_name}_{id(self)}"
|
|
self.server = RPCServer(self.App())
|
|
self.server.bind(self.address)
|
|
self.server.start()
|
|
time.sleep(0.1)
|
|
self.client = RPCClient(self.address)
|
|
|
|
def teardown_method(self):
|
|
"""Shutdown server and close client."""
|
|
self.client.close()
|
|
self.server.shutdown()
|
|
# Add a small delay to ensure the socket is fully released before the next test
|
|
time.sleep(0.5)
|
|
|
|
def run_sync_timeout_test(self):
|
|
with pytest.raises(RPCTimeout) as exc_info:
|
|
self.client.slow_operation(2.0).remote(timeout=0.1)
|
|
assert "timed out" in str(
|
|
exc_info.value), f"Timeout message not found: {exc_info.value}"
|
|
|
|
def run_async_timeout_test(self):
|
|
import asyncio
|
|
|
|
async def async_timeout():
|
|
with pytest.raises(RPCTimeout) as exc_info:
|
|
await self.client.slow_operation(2.0).remote_async(timeout=0.1)
|
|
assert "timed out" in str(
|
|
exc_info.value), f"Timeout message not found: {exc_info.value}"
|
|
|
|
asyncio.run(async_timeout())
|
|
|
|
def run_sync_success_test(self):
|
|
result = self.client.slow_operation(0.1).remote(timeout=10.0)
|
|
assert result == "completed"
|
|
print(f"final result: {result}")
|
|
|
|
def run_async_success_test(self):
|
|
import asyncio
|
|
|
|
async def async_success():
|
|
result = await self.client.slow_operation(0.1).remote_async(
|
|
timeout=10.0)
|
|
assert result == "completed"
|
|
print(f"final result: {result}")
|
|
return result
|
|
|
|
return asyncio.run(async_success())
|
|
|
|
@pytest.mark.parametrize("use_async", [True, False])
|
|
def test_rpc_timeout(self, use_async):
|
|
if use_async:
|
|
self.run_async_timeout_test()
|
|
self.run_async_success_test()
|
|
else:
|
|
self.run_sync_timeout_test()
|
|
self.run_sync_success_test()
|
|
|
|
|
|
class TestRpcShutdown:
|
|
|
|
def test_duplicate_shutdown(self):
|
|
|
|
class App:
|
|
|
|
def quick_task(self, task_id: int):
|
|
return f"quick_task_{task_id}"
|
|
|
|
with RpcServerWrapper(App(),
|
|
addr="ipc:///tmp/rpc_test_shutdown") as server:
|
|
time.sleep(0.1)
|
|
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
|
|
client.quick_task(1).remote()
|
|
|
|
# repeated shutdown should not raise an error
|
|
for i in range(10):
|
|
server.shutdown()
|
|
|
|
def test_submit_request_after_server_shutdown(self):
|
|
|
|
class App:
|
|
|
|
def foo(self, delay: int):
|
|
time.sleep(delay)
|
|
return "foo"
|
|
|
|
server = RPCServer(App())
|
|
server.bind("ipc:///tmp/rpc_test_shutdown")
|
|
server.start()
|
|
|
|
time.sleep(0.1)
|
|
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
|
|
# This task should be continued after server shutdown
|
|
res = client.foo(10).remote_future(timeout=12)
|
|
|
|
# The shutdown will block until all pending requests are finished
|
|
server.shutdown()
|
|
|
|
assert res.result() == "foo"
|
|
|
|
|
|
class TestApp:
|
|
"""Test application with various method types."""
|
|
|
|
def __init__(self):
|
|
self.call_count = 0
|
|
|
|
def sync_add(self, a: int, b: int) -> int:
|
|
"""Sync method."""
|
|
self.call_count += 1
|
|
return a + b
|
|
|
|
async def async_multiply(self, x: int, y: int) -> int:
|
|
"""Async method."""
|
|
await asyncio.sleep(0.01)
|
|
self.call_count += 1
|
|
return x * y
|
|
|
|
async def streaming_range(self, n: int):
|
|
"""Streaming generator."""
|
|
for i in range(n):
|
|
await asyncio.sleep(0.01)
|
|
yield i
|
|
|
|
async def streaming_error(self, n: int):
|
|
"""Streaming generator that raises error."""
|
|
for i in range(n):
|
|
if i == 2:
|
|
raise ValueError("Test error at i=2")
|
|
yield i
|
|
|
|
async def streaming_timeout(self, delay: float):
|
|
"""Streaming generator with configurable delay."""
|
|
for i in range(10):
|
|
await asyncio.sleep(delay)
|
|
yield i
|
|
|
|
|
|
class TestRpcAsync:
|
|
# Use setup_method/teardown_method for pytest class-based setup/teardown
|
|
def setup_method(self):
|
|
"""Setup RPC server and client for tests."""
|
|
self.app = TestApp()
|
|
self.server = RPCServer(self.app, num_workers=2, async_run_task=True)
|
|
self.server.bind("tcp://127.0.0.1:0") # Use random port
|
|
self.server.start()
|
|
# Get actual address after binding
|
|
address = f"tcp://127.0.0.1:{self.server.address.split(':')[-1]}"
|
|
self.client = RPCClient(address)
|
|
|
|
def teardown_method(self):
|
|
self.server.shutdown()
|
|
self.client.close()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sync_method(self):
|
|
"""Test traditional sync method still works."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
# Test sync call
|
|
result = client.sync_add(5, 3).remote()
|
|
assert result == 8
|
|
assert app.call_count == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_method(self):
|
|
"""Test async method execution."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
# Test async call
|
|
result = await client.async_multiply(4, 7).remote_async()
|
|
assert result == 28
|
|
assert app.call_count == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_basic(self):
|
|
"""Test basic streaming functionality."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
results = []
|
|
async for value in client.streaming_range(5).remote_streaming():
|
|
results.append(value)
|
|
|
|
assert results == [0, 1, 2, 3, 4]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_concurrent(self):
|
|
"""Test concurrent streaming calls."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
async def collect_stream(n):
|
|
results = []
|
|
async for value in client.streaming_range(n).remote_streaming():
|
|
results.append(value)
|
|
return results
|
|
|
|
# Run 3 concurrent streams
|
|
results = await asyncio.gather(collect_stream(3), collect_stream(4),
|
|
collect_stream(5))
|
|
|
|
assert results[0] == [0, 1, 2]
|
|
assert results[1] == [0, 1, 2, 3]
|
|
assert results[2] == [0, 1, 2, 3, 4]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_error_handling(self):
|
|
"""Test error handling in streaming."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
results = []
|
|
with pytest.raises(RPCStreamingError, match="Test error at i=2"):
|
|
async for value in client.streaming_error(5).remote_streaming():
|
|
results.append(value)
|
|
|
|
# Should have received values before error
|
|
assert results == [0, 1]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_timeout(self):
|
|
"""Test timeout handling in streaming."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
# Set short timeout
|
|
with pytest.raises(RPCTimeout):
|
|
async for value in client.streaming_timeout(
|
|
delay=2.0).remote_streaming(timeout=0.5):
|
|
pass # Should timeout before first yield
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mixed_calls(self):
|
|
"""Test mixing different call types."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
# Run sync, async, and streaming calls together
|
|
sync_result = client.sync_add(1, 2).remote()
|
|
async_future = client.async_multiply(3, 4).remote_future()
|
|
|
|
streaming_results = []
|
|
async for value in client.streaming_range(3).remote_streaming():
|
|
streaming_results.append(value)
|
|
|
|
async_result = async_future.result()
|
|
|
|
assert sync_result == 3
|
|
assert async_result == 12
|
|
assert streaming_results == [0, 1, 2]
|
|
assert app.call_count == 2 # sync + async (streaming doesn't increment)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_invalid_streaming_call(self):
|
|
"""Test calling non-streaming method with streaming."""
|
|
app, client, server = self.app, self.client, self.server
|
|
|
|
# This should fail because sync_add is not an async generator
|
|
with pytest.raises(RPCStreamingError):
|
|
async for value in client.sync_add(1, 2).remote_streaming():
|
|
pass
|
|
|
|
|
|
class TestResponsePickleError:
|
|
""" The pickle error will break the whole server, test the error handling. """
|
|
|
|
class App:
|
|
|
|
def unpickleable_return(self):
|
|
# Functions defined locally are not pickleable
|
|
def nested_function():
|
|
pass
|
|
|
|
return nested_function
|
|
|
|
async def unpickleable_streaming_return(self):
|
|
# Functions defined locally are not pickleable
|
|
def nested_function():
|
|
pass
|
|
|
|
yield nested_function
|
|
|
|
def test_unpickleable_error(self):
|
|
with RpcServerWrapper(
|
|
self.App(), addr="ipc:///tmp/rpc_test_pickle_error") as server:
|
|
with RPCClient("ipc:///tmp/rpc_test_pickle_error") as client:
|
|
with pytest.raises(RPCError) as exc_info:
|
|
client.unpickleable_return().remote()
|
|
|
|
assert "Failed to pickle response" in str(exc_info.value)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_unpickleable_streaming_error(self):
|
|
with RpcServerWrapper(self.App(),
|
|
addr="ipc:///tmp/rpc_test_pickle_error_streaming",
|
|
async_run_task=True) as server:
|
|
with RPCClient(
|
|
"ipc:///tmp/rpc_test_pickle_error_streaming") as client:
|
|
with pytest.raises(RPCStreamingError) as exc_info:
|
|
async for _ in client.unpickleable_streaming_return(
|
|
).remote_streaming():
|
|
pass
|
|
|
|
assert "Failed to pickle response" in str(exc_info.value)
|