# TensorRT-LLMs/tests/unittest/bindings/test_transfer_agent_bindings.py
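"""Unit tests for the TensorRT-LLM transfer agent bindings.

Covers the backend-independent enums, descriptor and config classes, plus
class-level and functional data-transfer tests for the NIXL and Mooncake
backends when they are compiled in.
"""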

import pytest
# Try to import the transfer agent binding module
try:
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as tab
HAS_TRANSFER_AGENT = True
# Check which backends are available
HAS_NIXL = getattr(tab, "NIXL_ENABLED", False)
HAS_MOONCAKE = getattr(tab, "MOONCAKE_ENABLED", False)
except ImportError:
HAS_TRANSFER_AGENT = False
HAS_NIXL = False
HAS_MOONCAKE = False
# Try to import torch for functional tests
try:
import torch
HAS_TORCH = True
HAS_CUDA = torch.cuda.is_available()
except ImportError:
HAS_TORCH = False
HAS_CUDA = False
pytestmark = pytest.mark.skipif(
not HAS_TRANSFER_AGENT,
reason="Transfer agent bindings not available (tensorrt_llm_transfer_agent_binding)",
)
# =============================================================================
# Common Tests (independent of backend)
# =============================================================================
def test_memory_type_enum():
"""Test MemoryType enum values."""
assert tab.MemoryType.DRAM is not None
assert tab.MemoryType.VRAM is not None
assert tab.MemoryType.BLK is not None
assert tab.MemoryType.OBJ is not None
assert tab.MemoryType.FILE is not None
# Verify they are distinct
assert tab.MemoryType.DRAM != tab.MemoryType.VRAM
assert tab.MemoryType.VRAM != tab.MemoryType.BLK
def test_transfer_op_enum():
"""Test TransferOp enum values."""
assert tab.TransferOp.READ is not None
assert tab.TransferOp.WRITE is not None
assert tab.TransferOp.READ != tab.TransferOp.WRITE
def test_transfer_state_enum():
"""Test TransferState enum values."""
assert tab.TransferState.IN_PROGRESS is not None
assert tab.TransferState.SUCCESS is not None
assert tab.TransferState.FAILURE is not None
# Verify they are distinct
assert tab.TransferState.IN_PROGRESS != tab.TransferState.SUCCESS
assert tab.TransferState.SUCCESS != tab.TransferState.FAILURE
assert tab.TransferState.IN_PROGRESS != tab.TransferState.FAILURE
def test_memory_desc():
"""Test MemoryDesc class."""
addr = 0x1000
length = 4096
device_id = 0
desc = tab.MemoryDesc(addr, length, device_id)
assert desc.addr == addr
assert desc.len == length
assert desc.device_id == device_id
def test_memory_desc_different_values():
"""Test MemoryDesc with different values."""
test_cases = [
(0x0, 1, 0),
(0xFFFFFFFF, 65536, 1),
(0x12345678, 1024, 7),
]
for addr, length, device_id in test_cases:
desc = tab.MemoryDesc(addr, length, device_id)
assert desc.addr == addr
assert desc.len == length
assert desc.device_id == device_id
def test_memory_descs():
"""Test MemoryDescs class."""
desc1 = tab.MemoryDesc(0x1000, 4096, 0)
desc2 = tab.MemoryDesc(0x2000, 8192, 0)
descs = tab.MemoryDescs(tab.MemoryType.VRAM, [desc1, desc2])
assert descs.type == tab.MemoryType.VRAM
assert len(descs.descs) == 2
assert descs.descs[0].addr == 0x1000
assert descs.descs[1].addr == 0x2000
def test_memory_descs_empty():
"""Test MemoryDescs with empty list."""
descs = tab.MemoryDescs(tab.MemoryType.DRAM, [])
assert descs.type == tab.MemoryType.DRAM
assert len(descs.descs) == 0
def test_agent_desc_from_string():
"""Test AgentDesc from string."""
test_data = "test_agent_descriptor"
desc = tab.AgentDesc(test_data)
assert desc.backend_agent_desc == test_data.encode()
def test_agent_desc_from_bytes():
"""Test AgentDesc from bytes."""
test_data = b"test_binary_data\x00\x01\x02"
desc = tab.AgentDesc(test_data)
assert desc.backend_agent_desc == test_data
def test_base_agent_config_default():
"""Test BaseAgentConfig with default values."""
config = tab.BaseAgentConfig()
    # Construction with no arguments must succeed with backend defaults.
assert config is not None
def test_base_agent_config_custom():
"""Test BaseAgentConfig with custom values."""
name = "test_agent"
use_prog_thread = True
multi_thread = False
use_listen_thread = True
num_workers = 4
config = tab.BaseAgentConfig(
name=name,
use_prog_thread=use_prog_thread,
multi_thread=multi_thread,
use_listen_thread=use_listen_thread,
num_workers=num_workers,
)
assert config.name == name
assert config.use_prog_thread == use_prog_thread
assert config.multi_thread == multi_thread
assert config.use_listen_thread == use_listen_thread
assert config.num_workers == num_workers
def test_base_agent_config_readwrite():
"""Test BaseAgentConfig read/write properties."""
config = tab.BaseAgentConfig()
config.name = "modified_name"
assert config.name == "modified_name"
config.use_prog_thread = False
assert config.use_prog_thread is False
config.multi_thread = True
assert config.multi_thread is True
config.use_listen_thread = True
assert config.use_listen_thread is True
config.num_workers = 8
assert config.num_workers == 8
def test_transfer_request():
"""Test TransferRequest class."""
src_desc = tab.MemoryDesc(0x1000, 4096, 0)
dst_desc = tab.MemoryDesc(0x2000, 4096, 1)
src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, [src_desc])
dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, [dst_desc])
remote_name = "remote_agent"
request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, remote_name)
assert request.op == tab.TransferOp.WRITE
assert request.remote_name == remote_name
assert request.src_descs.type == tab.MemoryType.VRAM
assert request.dst_descs.type == tab.MemoryType.VRAM
def test_transfer_request_read_op():
"""Test TransferRequest with READ operation."""
src_desc = tab.MemoryDesc(0x3000, 2048, 0)
dst_desc = tab.MemoryDesc(0x4000, 2048, 0)
src_descs = tab.MemoryDescs(tab.MemoryType.DRAM, [src_desc])
dst_descs = tab.MemoryDescs(tab.MemoryType.DRAM, [dst_desc])
request = tab.TransferRequest(tab.TransferOp.READ, src_descs, dst_descs, "another_remote")
assert request.op == tab.TransferOp.READ
assert request.remote_name == "another_remote"
def test_backend_availability_flags():
"""Test that backend availability flags are exposed."""
# These should always be defined (either True or False)
assert hasattr(tab, "NIXL_ENABLED")
assert hasattr(tab, "MOONCAKE_ENABLED")
assert isinstance(tab.NIXL_ENABLED, bool)
assert isinstance(tab.MOONCAKE_ENABLED, bool)
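# The availability flags above allow runtime backend selection. A minimal
# sketch (hypothetical helper, not used by these tests); it assumes only the
# NixlTransferAgent / MooncakeTransferAgent constructors exercised in the
# functional tests below.
def _make_any_agent(config):
    """Return a transfer agent for whichever backend is compiled in."""
    if HAS_NIXL:
        return tab.NixlTransferAgent(config)
    if HAS_MOONCAKE:
        return tab.MooncakeTransferAgent(config)
    raise RuntimeError("No transfer agent backend available")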
# =============================================================================
# NIXL-specific Tests
# =============================================================================
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlTransferAgent:
"""Test cases for NixlTransferAgent."""
def test_nixl_transfer_agent_class_exists(self):
"""Test that NixlTransferAgent class exists."""
assert hasattr(tab, "NixlTransferAgent")
def test_nixl_transfer_status_class_exists(self):
"""Test that NixlTransferStatus class exists."""
assert hasattr(tab, "NixlTransferStatus")
def test_nixl_transfer_agent_is_base_subclass(self):
"""Test that NixlTransferAgent is a subclass of BaseTransferAgent."""
assert issubclass(tab.NixlTransferAgent, tab.BaseTransferAgent)
def test_nixl_transfer_status_is_base_subclass(self):
"""Test that NixlTransferStatus is a subclass of TransferStatus."""
assert issubclass(tab.NixlTransferStatus, tab.TransferStatus)
def test_nixl_transfer_agent_has_required_methods(self):
"""Test that NixlTransferAgent has all required methods."""
required_methods = [
"register_memory",
"deregister_memory",
"load_remote_agent",
"load_remote_agent_by_connection",
"get_local_agent_desc",
"get_local_connection_info",
"invalidate_remote_agent",
"submit_transfer_requests",
"notify_sync_message",
"get_notified_sync_messages",
"check_remote_descs",
]
for method in required_methods:
assert hasattr(tab.NixlTransferAgent, method), f"Missing method: {method}"
# =============================================================================
# Mooncake-specific Tests
# =============================================================================
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeTransferAgent:
"""Test cases for MooncakeTransferAgent."""
def test_mooncake_transfer_agent_class_exists(self):
"""Test that MooncakeTransferAgent class exists."""
assert hasattr(tab, "MooncakeTransferAgent")
def test_mooncake_transfer_status_class_exists(self):
"""Test that MooncakeTransferStatus class exists."""
assert hasattr(tab, "MooncakeTransferStatus")
def test_mooncake_transfer_agent_is_base_subclass(self):
"""Test that MooncakeTransferAgent is a subclass of BaseTransferAgent."""
assert issubclass(tab.MooncakeTransferAgent, tab.BaseTransferAgent)
def test_mooncake_transfer_status_is_base_subclass(self):
"""Test that MooncakeTransferStatus is a subclass of TransferStatus."""
assert issubclass(tab.MooncakeTransferStatus, tab.TransferStatus)
def test_mooncake_transfer_agent_has_required_methods(self):
"""Test that MooncakeTransferAgent has all required methods."""
required_methods = [
"register_memory",
"deregister_memory",
"load_remote_agent",
"load_remote_agent_by_connection",
"get_local_agent_desc",
"get_local_connection_info",
"invalidate_remote_agent",
"submit_transfer_requests",
"notify_sync_message",
"get_notified_sync_messages",
"check_remote_descs",
]
for method in required_methods:
assert hasattr(tab.MooncakeTransferAgent, method), f"Missing method: {method}"
# =============================================================================
# Functional Tests - Data Transfer Validation
# =============================================================================
def _create_memory_descs_from_tensor(tensor, memory_type):
"""Helper to create MemoryDescs from a torch tensor."""
addr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
device_id = tensor.device.index if tensor.is_cuda else 0
desc = tab.MemoryDesc(addr, size, device_id)
return tab.MemoryDescs(memory_type, [desc])
@pytest.mark.skipif(
not (HAS_TORCH and HAS_CUDA),
reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlFunctionalTransfer:
"""Functional tests for NIXL data transfer between two agents."""
def test_nixl_write_transfer_gpu_tensor(self):
"""Test WRITE transfer of GPU tensor data between two NIXL agents."""
device = torch.device("cuda:0")
# Create source tensor with known data pattern
src_tensor = torch.arange(1024, dtype=torch.float32, device=device)
# Create destination tensor (zeros)
dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)
# Verify initial state
assert not torch.equal(src_tensor, dst_tensor)
# Create two agents
config_a = tab.BaseAgentConfig(
name="agent_a",
use_prog_thread=True,
use_listen_thread=False,
)
config_b = tab.BaseAgentConfig(
name="agent_b",
use_prog_thread=True,
use_listen_thread=False,
)
agent_a = tab.NixlTransferAgent(config_a)
agent_b = tab.NixlTransferAgent(config_b)
# Register memory regions
src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
# Exchange agent descriptors
agent_a_desc = agent_a.get_local_agent_desc()
agent_b_desc = agent_b.get_local_agent_desc()
agent_a.load_remote_agent("agent_b", agent_b_desc)
agent_b.load_remote_agent("agent_a", agent_a_desc)
# Create transfer request: agent_a writes src_tensor to agent_b's dst_tensor
request = tab.TransferRequest(
tab.TransferOp.WRITE,
src_descs, # local source
dst_descs, # remote destination
"agent_b", # remote agent name
)
# Submit transfer and wait for completion
status = agent_a.submit_transfer_requests(request)
result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"
# Synchronize CUDA to ensure transfer is complete
torch.cuda.synchronize()
# Verify data was transferred correctly
assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)
def test_nixl_write_transfer_multiple_chunks(self):
"""Test WRITE transfer with multiple memory chunks."""
device = torch.device("cuda:0")
# Create multiple source tensors
src_tensors = [
torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
for i in range(4)
]
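        # Chunk i holds the values [i*256, (i+1)*256), so a dropped or
        # reordered chunk is detectable after the transfer.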
# Create corresponding destination tensors
dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]
# Create agents
config_a = tab.BaseAgentConfig(
name="agent_a", use_prog_thread=True, use_listen_thread=False
)
config_b = tab.BaseAgentConfig(
name="agent_b", use_prog_thread=True, use_listen_thread=False
)
agent_a = tab.NixlTransferAgent(config_a)
agent_b = tab.NixlTransferAgent(config_b)
# Create memory descriptors for all chunks
src_memory_descs = []
dst_memory_descs = []
for src, dst in zip(src_tensors, dst_tensors):
src_memory_descs.append(
tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
)
dst_memory_descs.append(
tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
)
src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)
# Register memory
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
# Exchange agent info
agent_a.load_remote_agent("agent_b", agent_b.get_local_agent_desc())
agent_b.load_remote_agent("agent_a", agent_a.get_local_agent_desc())
# Transfer
request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, "agent_b")
status = agent_a.submit_transfer_requests(request)
result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS
torch.cuda.synchronize()
# Verify all chunks
for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
assert torch.equal(src, dst), f"Data mismatch in chunk {i}"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)
@pytest.mark.skipif(
not (HAS_TORCH and HAS_CUDA),
reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeFunctionalTransfer:
"""Functional tests for Mooncake data transfer between two agents."""
def test_mooncake_write_transfer_gpu_tensor(self):
"""Test WRITE transfer of GPU tensor data between two Mooncake agents."""
device = torch.device("cuda:0")
# Create source tensor with known data pattern
src_tensor = torch.arange(1024, dtype=torch.float32, device=device)
# Create destination tensor (zeros)
dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)
# Verify initial state
assert not torch.equal(src_tensor, dst_tensor)
# Create two agents
config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)
agent_a = tab.MooncakeTransferAgent(config_a)
agent_b = tab.MooncakeTransferAgent(config_b)
# Register memory regions
src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
agent_a_desc = agent_a.get_local_agent_desc()
agent_b_desc = agent_b.get_local_agent_desc()
agent_a.load_remote_agent("mooncake_agent_b", agent_b_desc)
agent_b.load_remote_agent("mooncake_agent_a", agent_a_desc)
request = tab.TransferRequest(
tab.TransferOp.WRITE,
src_descs, # local source
dst_descs, # remote destination
"mooncake_agent_b", # remote agent name
)
        # Submit transfer and wait for completion
        status = agent_a.submit_transfer_requests(request)
        result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"
# Synchronize CUDA to ensure transfer is complete
torch.cuda.synchronize()
# Verify data was transferred correctly
assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)
def test_mooncake_write_transfer_multiple_chunks(self):
"""Test WRITE transfer with multiple memory chunks."""
device = torch.device("cuda:0")
# Create multiple source tensors
src_tensors = [
torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
for i in range(4)
]
# Create corresponding destination tensors
dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]
# Create agents
config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)
agent_a = tab.MooncakeTransferAgent(config_a)
agent_b = tab.MooncakeTransferAgent(config_b)
# Create memory descriptors for all chunks
src_memory_descs = []
dst_memory_descs = []
for src, dst in zip(src_tensors, dst_tensors):
src_memory_descs.append(
tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
)
dst_memory_descs.append(
tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
)
src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)
# Register memory
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
# Exchange agent info
agent_a.load_remote_agent("mooncake_agent_b", agent_b.get_local_agent_desc())
agent_b.load_remote_agent("mooncake_agent_a", agent_a.get_local_agent_desc())
# Transfer
request = tab.TransferRequest(
tab.TransferOp.WRITE, src_descs, dst_descs, "mooncake_agent_b"
)
status = agent_a.submit_transfer_requests(request)
result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS
torch.cuda.synchronize()
# Verify all chunks
for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
assert torch.equal(src, dst), f"Data mismatch in chunk {i}"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)