TensorRT-LLMs/tests/unittest/executor/test_ipc.py
Yan Chunwei 04a39a4e2b
[None][chore] enable test_ipc.py (#9865)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
2025-12-11 17:47:14 +08:00

849 lines
26 KiB
Python

import asyncio
import time
from threading import Thread
import pytest
import zmq
from tensorrt_llm.executor.ipc import ZeroMqQueue
class TestIpcBasics:
    """Test basic synchronous IPC operations."""

    def test_pair_socket_with_hmac(self):
        """Test PAIR socket with HMAC encryption."""
        # Create server
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="test_server",
            use_hmac_encryption=True,
        )
        # Create client with the server's full address tuple; with encryption
        # enabled the client must also receive the server's HMAC key.
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=False,
            name="test_client",
            use_hmac_encryption=True,
        )
        try:
            # Test basic send/receive
            test_data = {"message": "hello", "value": 42}
            client.put(test_data)
            received = server.get()
            assert received == test_data
            # Test reverse direction
            response = {"status": "ok", "result": 100}
            server.put(response)
            received = client.get()
            assert received == response
        finally:
            client.close()
            server.close()

    def test_pair_socket_without_hmac(self):
        """Test PAIR socket without HMAC encryption."""
        # Create server without HMAC
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="test_server_no_hmac",
            use_hmac_encryption=False,
        )
        # Create client; the key slot must be None when encryption is off.
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=False,
            name="test_client_no_hmac",
            use_hmac_encryption=False,
        )
        try:
            # Test send/receive
            test_data = {"message": "hello without encryption", "numbers": [1, 2, 3]}
            client.put(test_data)
            received = server.get()
            assert received == test_data
        finally:
            client.close()
            server.close()

    def test_poll_timeout(self):
        """Test poll timeout behavior."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="test_poll_server",
            use_hmac_encryption=False,
        )
        try:
            # Poll should time out when no data is available. Use a monotonic
            # clock: time.time() follows the wall clock and can jump.
            start = time.monotonic()
            result = server.poll(timeout=1)
            elapsed = time.monotonic() - start
            assert result is False
            # Allow a small tolerance around the 1 s timeout so timer
            # granularity and scheduling jitter on busy hosts don't flake.
            assert elapsed >= 0.95
            assert elapsed < 2.0
        finally:
            server.close()

    def test_poll_with_data(self):
        """Test poll returns True when data is available."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="test_poll_data_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=False,
            name="test_poll_data_client",
            use_hmac_encryption=False,
        )
        try:
            # Send data in background after a short delay so poll() has to wait.
            def send_data():
                time.sleep(0.1)
                client.put({"data": "test"})

            # daemon=True: a stuck sender must not keep the interpreter alive.
            thread = Thread(target=send_data, daemon=True)
            thread.start()
            try:
                # Poll should return True
                result = server.poll(timeout=2)
                assert result is True
                # Verify data
                received = server.get()
                assert received == {"data": "test"}
            finally:
                # Join even when an assertion failed, so the sender is done
                # with `client` before the sockets are closed below.
                thread.join(timeout=5)
        finally:
            client.close()
            server.close()

    def test_router_socket_with_hmac(self):
        """Test ROUTER socket with HMAC encryption and identity tracking."""
        # Create ROUTER server
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=False,
            name="test_router_server",
            use_hmac_encryption=True,
        )
        # Create DEALER client
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=False,
            name="test_dealer_client",
            use_hmac_encryption=True,
        )
        try:
            # Client sends request
            request = {"action": "process", "data": [1, 2, 3]}
            client.put(request)
            # Server receives and tracks identity
            received = server.get()
            assert received == request
            # Server sends response (using stored identity)
            response = {"status": "done", "result": 6}
            server.put(response)
            # Client receives response
            received = client.get()
            assert received == response
        finally:
            client.close()
            server.close()

    def test_dealer_notify_with_retry(self):
        """Test DEALER socket notify_with_retry mechanism."""
        # Create ROUTER server
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=False,
            name="test_router_ack_server",
            use_hmac_encryption=False,
        )
        # Create DEALER client
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=False,
            name="test_dealer_ack_client",
            use_hmac_encryption=False,
        )
        try:
            # Messages seen by the server thread. The check happens on the
            # test thread: an AssertionError raised inside a worker thread
            # does not fail a pytest test (it is only reported as a warning).
            server_received = []

            # Server thread that acknowledges messages
            def server_ack():
                server_received.append(server.get())
                # Send ACK
                server.put({"ack": True})

            thread = Thread(target=server_ack, daemon=True)
            thread.start()
            try:
                # Client sends with retry
                result = client.notify_with_retry({"notify": "test"}, max_retries=3, timeout=1)
                assert result is True
                assert server_received == [{"notify": "test"}]
            finally:
                # Make sure the server thread is done with `server` before
                # the sockets are closed below.
                thread.join(timeout=5)
        finally:
            client.close()
            server.close()

    def test_dealer_notify_with_retry_timeout(self):
        """Test DEALER socket notify_with_retry timeout behavior."""
        # Create ROUTER server (but don't respond)
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=False,
            name="test_router_no_ack_server",
            use_hmac_encryption=False,
        )
        # Create DEALER client
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=False,
            name="test_dealer_no_ack_client",
            use_hmac_encryption=False,
        )
        try:
            # Client sends but server doesn't acknowledge
            result = client.notify_with_retry({"notify": "test"}, max_retries=2, timeout=0.5)
            assert result is False
        finally:
            client.close()
            server.close()

    def test_hmac_key_generation(self):
        """Test that server generates HMAC key when encryption is enabled."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="test_hmac_gen",
            use_hmac_encryption=True,
        )
        try:
            # Server should have generated a 32-byte HMAC key
            assert server.hmac_key is not None
            assert len(server.hmac_key) == 32
        finally:
            server.close()

    def test_hmac_validation_error_client_no_key(self):
        """Test that client without HMAC key raises ValueError when encryption enabled."""
        with pytest.raises(ValueError, match="Client must receive HMAC key"):
            ZeroMqQueue(
                address=("tcp://127.0.0.1:5555", None),  # No HMAC key
                socket_type=zmq.PAIR,
                is_server=False,
                is_async=False,
                name="test_client_no_key",
                use_hmac_encryption=True,  # But encryption enabled
            )

    def test_hmac_validation_error_key_when_disabled(self):
        """Test that providing HMAC key when encryption disabled raises ValueError."""
        with pytest.raises(ValueError, match="should not receive HMAC key"):
            ZeroMqQueue(
                address=("tcp://127.0.0.1:5555", b"some_key"),  # Has key
                socket_type=zmq.PAIR,
                is_server=False,
                is_async=False,
                name="test_client_key_disabled",
                use_hmac_encryption=False,  # But encryption disabled
            )

    def test_put_noblock_retry(self):
        """Test put_noblock with retry mechanism."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="test_noblock_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=False,
            name="test_noblock_client",
            use_hmac_encryption=False,
        )
        try:
            # Send with put_noblock
            test_data = {"nonblocking": True, "value": 123}
            client.put_noblock(test_data, retry=3, wait_time=0.001)
            # Should be able to receive
            received = server.get()
            assert received == test_data
        finally:
            client.close()
            server.close()
class TestIpcAsyncBasics:
    """Test asynchronous IPC operations."""

    @pytest.mark.asyncio
    async def test_async_pair_with_hmac(self):
        """Test async PAIR socket with HMAC encryption."""
        # Create async server
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=True,
            name="async_server",
            use_hmac_encryption=True,
        )
        # Create async client; server.address carries the HMAC key as well.
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=True,
            name="async_client",
            use_hmac_encryption=True,
        )
        try:
            # Test async send/receive
            test_data = {"async": True, "value": 999}
            await client.put_async(test_data)
            received = await server.get_async()
            assert received == test_data
            # Test reverse direction
            response = {"status": "async_ok"}
            await server.put_async(response)
            received = await client.get_async()
            assert received == response
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_async_pair_without_hmac(self):
        """Test async PAIR socket without HMAC encryption."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=True,
            name="async_server_no_hmac",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=True,
            name="async_client_no_hmac",
            use_hmac_encryption=False,
        )
        try:
            # Test async operations
            test_data = {"no_encryption": True, "items": [1, 2, 3, 4, 5]}
            await client.put_async(test_data)
            received = await server.get_async()
            assert received == test_data
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_async_router_with_identity(self):
        """Test async ROUTER socket with identity handling."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=True,
            name="async_router_server",
            use_hmac_encryption=True,
        )
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=True,
            name="async_dealer_client",
            use_hmac_encryption=True,
        )
        try:
            # Client sends async request
            request = {"async_request": "process"}
            await client.put_async(request)
            # Server receives with identity
            received = await server.get_async()
            assert received == request
            # Server replies
            response = {"async_response": "completed"}
            await server.put_async(response)
            # Client receives
            received = await client.get_async()
            assert received == response
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_get_async_noblock_timeout(self):
        """Test get_async_noblock timeout expiration."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=True,
            name="async_timeout_server",
            use_hmac_encryption=False,
        )
        try:
            # Should timeout when no data available
            with pytest.raises(asyncio.TimeoutError):
                await server.get_async_noblock(timeout=0.5)
        finally:
            server.close()

    @pytest.mark.asyncio
    async def test_get_async_noblock_success(self):
        """Test get_async_noblock successful receive before timeout."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=True,
            name="async_noblock_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=True,
            name="async_noblock_client",
            use_hmac_encryption=False,
        )
        try:
            # Send data in background after a short delay
            async def send_delayed():
                await asyncio.sleep(0.1)
                await client.put_async({"delayed": True})

            send_task = asyncio.create_task(send_delayed())
            try:
                # Should receive before timeout
                received = await server.get_async_noblock(timeout=2.0)
                assert received == {"delayed": True}
                await send_task
            finally:
                # If anything above failed, don't leave the sender pending
                # (and never awaited) while its socket is closed below.
                # cancel() is a no-op on an already-finished task.
                send_task.cancel()
                await asyncio.gather(send_task, return_exceptions=True)
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_put_async_noblock(self):
        """Test put_async_noblock with NOBLOCK flag."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=True,
            name="async_put_noblock_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=True,
            name="async_put_noblock_client",
            use_hmac_encryption=False,
        )
        try:
            # Send with noblock
            test_data = {"noblock_async": True}
            await client.put_async_noblock(test_data)
            # Should be able to receive
            received = await server.get_async()
            assert received == test_data
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_async_router_without_hmac(self):
        """Test async ROUTER socket without HMAC encryption."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=True,
            name="async_router_server_no_hmac",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=True,
            name="async_dealer_client_no_hmac",
            use_hmac_encryption=False,
        )
        try:
            # Client sends async request
            request = {"async_request": "process_no_hmac"}
            await client.put_async(request)
            # Server receives with identity
            received = await server.get_async()
            assert received == request
            # Server replies
            response = {"async_response": "completed_no_hmac"}
            await server.put_async(response)
            # Client receives
            received = await client.get_async()
            assert received == response
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_async_router_get_noblock(self):
        """Test get_async_noblock on ROUTER socket (handling multipart)."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=True,
            name="async_router_noblock_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=True,
            name="async_dealer_noblock_client",
            use_hmac_encryption=False,
        )
        try:
            # Client sends async request
            request = {"noblock_request": "test"}

            # Send with small delay to ensure we test the polling/waiting
            async def send_delayed():
                await asyncio.sleep(0.1)
                await client.put_async(request)

            send_task = asyncio.create_task(send_delayed())
            try:
                # Server receives using get_async_noblock.
                # This exercises the ROUTER specific recv_multipart path.
                received = await server.get_async_noblock(timeout=2.0)
                assert received == request
                # Ensure identity was captured so we can reply
                assert server._last_identity is not None
                # Server replies
                response = {"noblock_response": "done"}
                await server.put_async(response)
                # Client receives
                received = await client.get_async()
                assert received == response
                await send_task
            finally:
                # Reap the sender even on failure so no task is left pending
                # once the sockets are closed; no-op if it already finished.
                send_task.cancel()
                await asyncio.gather(send_task, return_exceptions=True)
        finally:
            client.close()
            server.close()
class TestIpcPressureTest:
    """Test performance and load handling."""

    def test_high_frequency_small_messages(self):
        """Test sending many small messages rapidly."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="pressure_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=False,
            name="pressure_client",
            use_hmac_encryption=False,
        )
        num_messages = 10000
        try:
            # Send many small messages from a background thread.
            def sender():
                for i in range(num_messages):
                    client.put({"id": i, "data": f"msg_{i}"})

            # Receive on the test thread so assertion failures actually fail
            # the test (worker-thread exceptions are not propagated by pytest).
            def receiver():
                received_count = 0
                for i in range(num_messages):
                    msg = server.get()
                    assert msg["id"] == i
                    assert msg["data"] == f"msg_{i}"
                    received_count += 1
                return received_count

            # daemon=True: if the receiver fails early, the sender may be left
            # blocked in put(); a daemon thread can't keep the interpreter alive.
            send_thread = Thread(target=sender, daemon=True)
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            send_thread.start()
            count = receiver()
            send_thread.join()
            elapsed = time.perf_counter() - start_time
            # Verify all messages received
            assert count == num_messages
            print(
                f"\nHigh frequency test: {num_messages} messages in {elapsed:.2f}s "
                f"({num_messages / elapsed:.0f} msg/s)"
            )
        finally:
            client.close()
            server.close()

    def test_large_message_size(self):
        """Test sending large messages with HMAC encryption."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=False,
            name="large_msg_server",
            use_hmac_encryption=True,
        )
        client = ZeroMqQueue(
            address=server.address,
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=False,
            name="large_msg_client",
            use_hmac_encryption=True,
        )
        num_messages = 100
        message_size = 1024 * 1024  # 1 MB
        try:
            # Build the 1 MB payload once; only the id changes per message, so
            # the timed loop measures IPC rather than string construction.
            payload = "x" * message_size
            start_time = time.perf_counter()
            for i in range(num_messages):
                client.put({"id": i, "payload": payload})
                received = server.get()
                assert received["id"] == i
                assert len(received["payload"]) == message_size
                # Check content too, not just length, to catch corruption.
                assert received["payload"] == payload
            elapsed = time.perf_counter() - start_time
            total_mb = (num_messages * message_size) / (1024 * 1024)
            print(
                f"\nLarge message test: {num_messages} x 1MB messages in {elapsed:.2f}s "
                f"({total_mb / elapsed:.1f} MB/s)"
            )
        finally:
            client.close()
            server.close()

    @pytest.mark.asyncio
    async def test_concurrent_async_access(self):
        """Test multiple async coroutines sending/receiving simultaneously."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.PAIR,
            is_server=True,
            is_async=True,
            name="concurrent_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.PAIR,
            is_server=False,
            is_async=True,
            name="concurrent_client",
            use_hmac_encryption=False,
        )
        num_messages = 1000
        try:
            # Sender coroutine
            async def sender():
                for i in range(num_messages):
                    await client.put_async({"id": i, "data": f"concurrent_{i}"})
                    if i % 100 == 0:
                        await asyncio.sleep(0.001)  # Small yield

            # Receiver coroutine
            async def receiver():
                received_ids = set()
                for _ in range(num_messages):
                    msg = await server.get_async()
                    received_ids.add(msg["id"])
                return received_ids

            # Run concurrently
            start_time = time.perf_counter()
            sender_task = asyncio.create_task(sender())
            receiver_task = asyncio.create_task(receiver())
            try:
                # gather() surfaces the first failure from either side at
                # once; awaiting the receiver alone would hang forever if the
                # sender died.
                received_ids, _ = await asyncio.gather(receiver_task, sender_task)
            finally:
                # On failure, cancel whichever task is still running and reap
                # both so nothing is left pending when the sockets close.
                sender_task.cancel()
                receiver_task.cancel()
                await asyncio.gather(sender_task, receiver_task, return_exceptions=True)
            elapsed = time.perf_counter() - start_time
            # Verify all messages received
            assert len(received_ids) == num_messages
            assert received_ids == set(range(num_messages))
            print(f"\nConcurrent async test: {num_messages} messages in {elapsed:.2f}s")
        finally:
            client.close()
            server.close()

    def test_router_socket_multiple_requests(self):
        """Test ROUTER socket handling multiple sequential requests."""
        server = ZeroMqQueue(
            address=None,
            socket_type=zmq.ROUTER,
            is_server=True,
            is_async=False,
            name="router_load_server",
            use_hmac_encryption=False,
        )
        client = ZeroMqQueue(
            address=(server.address[0], None),
            socket_type=zmq.DEALER,
            is_server=False,
            is_async=False,
            name="dealer_load_client",
            use_hmac_encryption=False,
        )
        num_requests = 1000
        try:
            start_time = time.perf_counter()
            for i in range(num_requests):
                # Client sends request
                client.put({"request_id": i, "action": "process"})
                # Server receives
                request = server.get()
                assert request["request_id"] == i
                # Server responds
                server.put({"request_id": i, "result": i * 2})
                # Client receives response
                response = client.get()
                assert response["request_id"] == i
                assert response["result"] == i * 2
            elapsed = time.perf_counter() - start_time
            print(
                f"\nROUTER socket test: {num_requests} round-trips in {elapsed:.2f}s "
                f"({num_requests / elapsed:.0f} req/s)"
            )
        finally:
            client.close()
            server.close()