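"""Tests for the TensorRT-LLM transfer agent Python bindings.

Covers the backend-independent binding surface (enums, MemoryDesc, MemoryDescs,
AgentDesc, BaseAgentConfig, TransferRequest) plus NIXL- and Mooncake-specific
API checks and functional GPU transfer tests that run only when the
corresponding backend, torch, and CUDA are available.
"""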

import pytest

# Try to import the transfer agent binding module
try:
    import tensorrt_llm.tensorrt_llm_transfer_agent_binding as tab

    HAS_TRANSFER_AGENT = True
    # Check which backends are available
    HAS_NIXL = getattr(tab, "NIXL_ENABLED", False)
    HAS_MOONCAKE = getattr(tab, "MOONCAKE_ENABLED", False)
except ImportError:
    HAS_TRANSFER_AGENT = False
    HAS_NIXL = False
    HAS_MOONCAKE = False

# Try to import torch for functional tests
try:
    import torch

    HAS_TORCH = True
    HAS_CUDA = torch.cuda.is_available()
except ImportError:
    HAS_TORCH = False
    HAS_CUDA = False


pytestmark = pytest.mark.skipif(
    not HAS_TRANSFER_AGENT,
    reason="Transfer agent bindings not available (tensorrt_llm_transfer_agent_binding)",
)


# =============================================================================
# Common Tests (independent of backend)
# =============================================================================


def test_memory_type_enum():
    """Test MemoryType enum values."""
    assert tab.MemoryType.DRAM is not None
    assert tab.MemoryType.VRAM is not None
    assert tab.MemoryType.BLK is not None
    assert tab.MemoryType.OBJ is not None
    assert tab.MemoryType.FILE is not None

    # Verify they are distinct
    assert tab.MemoryType.DRAM != tab.MemoryType.VRAM
    assert tab.MemoryType.VRAM != tab.MemoryType.BLK


def test_transfer_op_enum():
    """Test TransferOp enum values."""
    assert tab.TransferOp.READ is not None
    assert tab.TransferOp.WRITE is not None
    assert tab.TransferOp.READ != tab.TransferOp.WRITE


def test_transfer_state_enum():
    """Test TransferState enum values."""
    assert tab.TransferState.IN_PROGRESS is not None
    assert tab.TransferState.SUCCESS is not None
    assert tab.TransferState.FAILURE is not None

    # Verify they are distinct
    assert tab.TransferState.IN_PROGRESS != tab.TransferState.SUCCESS
    assert tab.TransferState.SUCCESS != tab.TransferState.FAILURE
    assert tab.TransferState.IN_PROGRESS != tab.TransferState.FAILURE


def test_memory_desc():
    """Test MemoryDesc class."""
    addr = 0x1000
    length = 4096
    device_id = 0

    desc = tab.MemoryDesc(addr, length, device_id)

    assert desc.addr == addr
    assert desc.len == length
    assert desc.device_id == device_id


def test_memory_desc_different_values():
    """Test MemoryDesc with different values."""
    test_cases = [
        (0x0, 1, 0),
        (0xFFFFFFFF, 65536, 1),
        (0x12345678, 1024, 7),
    ]

    for addr, length, device_id in test_cases:
        desc = tab.MemoryDesc(addr, length, device_id)
        assert desc.addr == addr
        assert desc.len == length
        assert desc.device_id == device_id


def test_memory_descs():
    """Test MemoryDescs class."""
    desc1 = tab.MemoryDesc(0x1000, 4096, 0)
    desc2 = tab.MemoryDesc(0x2000, 8192, 0)

    descs = tab.MemoryDescs(tab.MemoryType.VRAM, [desc1, desc2])

    assert descs.type == tab.MemoryType.VRAM
    assert len(descs.descs) == 2
    assert descs.descs[0].addr == 0x1000
    assert descs.descs[1].addr == 0x2000


def test_memory_descs_empty():
    """Test MemoryDescs with an empty list."""
    descs = tab.MemoryDescs(tab.MemoryType.DRAM, [])
    assert descs.type == tab.MemoryType.DRAM
    assert len(descs.descs) == 0


def test_agent_desc_from_string():
    """Test AgentDesc from string."""
    test_data = "test_agent_descriptor"
    desc = tab.AgentDesc(test_data)
    assert desc.backend_agent_desc == test_data.encode()


def test_agent_desc_from_bytes():
    """Test AgentDesc from bytes."""
    test_data = b"test_binary_data\x00\x01\x02"
    desc = tab.AgentDesc(test_data)
    assert desc.backend_agent_desc == test_data


def test_base_agent_config_default():
    """Test BaseAgentConfig with default values."""
    config = tab.BaseAgentConfig()
    # Default values should be set
    assert config is not None


def test_base_agent_config_custom():
    """Test BaseAgentConfig with custom values."""
    name = "test_agent"
    use_prog_thread = True
    multi_thread = False
    use_listen_thread = True
    num_workers = 4

    config = tab.BaseAgentConfig(
        name=name,
        use_prog_thread=use_prog_thread,
        multi_thread=multi_thread,
        use_listen_thread=use_listen_thread,
        num_workers=num_workers,
    )

    assert config.name == name
    assert config.use_prog_thread == use_prog_thread
    assert config.multi_thread == multi_thread
    assert config.use_listen_thread == use_listen_thread
    assert config.num_workers == num_workers


def test_base_agent_config_readwrite():
    """Test BaseAgentConfig read/write properties."""
    config = tab.BaseAgentConfig()

    config.name = "modified_name"
    assert config.name == "modified_name"

    config.use_prog_thread = False
    assert config.use_prog_thread is False

    config.multi_thread = True
    assert config.multi_thread is True

    config.use_listen_thread = True
    assert config.use_listen_thread is True

    config.num_workers = 8
    assert config.num_workers == 8


def test_transfer_request():
    """Test TransferRequest class."""
    src_desc = tab.MemoryDesc(0x1000, 4096, 0)
    dst_desc = tab.MemoryDesc(0x2000, 4096, 1)

    src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, [src_desc])
    dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, [dst_desc])

    remote_name = "remote_agent"

    request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, remote_name)

    assert request.op == tab.TransferOp.WRITE
    assert request.remote_name == remote_name
    assert request.src_descs.type == tab.MemoryType.VRAM
    assert request.dst_descs.type == tab.MemoryType.VRAM


def test_transfer_request_read_op():
    """Test TransferRequest with READ operation."""
    src_desc = tab.MemoryDesc(0x3000, 2048, 0)
    dst_desc = tab.MemoryDesc(0x4000, 2048, 0)

    src_descs = tab.MemoryDescs(tab.MemoryType.DRAM, [src_desc])
    dst_descs = tab.MemoryDescs(tab.MemoryType.DRAM, [dst_desc])

    request = tab.TransferRequest(tab.TransferOp.READ, src_descs, dst_descs, "another_remote")

    assert request.op == tab.TransferOp.READ
    assert request.remote_name == "another_remote"


def test_backend_availability_flags():
    """Test that backend availability flags are exposed."""
    # These should always be defined (either True or False)
    assert hasattr(tab, "NIXL_ENABLED")
    assert hasattr(tab, "MOONCAKE_ENABLED")
    assert isinstance(tab.NIXL_ENABLED, bool)
    assert isinstance(tab.MOONCAKE_ENABLED, bool)


# =============================================================================
# NIXL-specific Tests
# =============================================================================


@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlTransferAgent:
    """Test cases for NixlTransferAgent."""

    def test_nixl_transfer_agent_class_exists(self):
        """Test that NixlTransferAgent class exists."""
        assert hasattr(tab, "NixlTransferAgent")

    def test_nixl_transfer_status_class_exists(self):
        """Test that NixlTransferStatus class exists."""
        assert hasattr(tab, "NixlTransferStatus")

    def test_nixl_transfer_agent_is_base_subclass(self):
        """Test that NixlTransferAgent is a subclass of BaseTransferAgent."""
        assert issubclass(tab.NixlTransferAgent, tab.BaseTransferAgent)

    def test_nixl_transfer_status_is_base_subclass(self):
        """Test that NixlTransferStatus is a subclass of TransferStatus."""
        assert issubclass(tab.NixlTransferStatus, tab.TransferStatus)

    def test_nixl_transfer_agent_has_required_methods(self):
        """Test that NixlTransferAgent has all required methods."""
        required_methods = [
            "register_memory",
            "deregister_memory",
            "load_remote_agent",
            "load_remote_agent_by_connection",
            "get_local_agent_desc",
            "get_local_connection_info",
            "invalidate_remote_agent",
            "submit_transfer_requests",
            "notify_sync_message",
            "get_notified_sync_messages",
            "check_remote_descs",
        ]
        for method in required_methods:
            assert hasattr(tab.NixlTransferAgent, method), f"Missing method: {method}"


# =============================================================================
# Mooncake-specific Tests
# =============================================================================


@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeTransferAgent:
    """Test cases for MooncakeTransferAgent."""

    def test_mooncake_transfer_agent_class_exists(self):
        """Test that MooncakeTransferAgent class exists."""
        assert hasattr(tab, "MooncakeTransferAgent")

    def test_mooncake_transfer_status_class_exists(self):
        """Test that MooncakeTransferStatus class exists."""
        assert hasattr(tab, "MooncakeTransferStatus")

    def test_mooncake_transfer_agent_is_base_subclass(self):
        """Test that MooncakeTransferAgent is a subclass of BaseTransferAgent."""
        assert issubclass(tab.MooncakeTransferAgent, tab.BaseTransferAgent)

    def test_mooncake_transfer_status_is_base_subclass(self):
        """Test that MooncakeTransferStatus is a subclass of TransferStatus."""
        assert issubclass(tab.MooncakeTransferStatus, tab.TransferStatus)

    def test_mooncake_transfer_agent_has_required_methods(self):
        """Test that MooncakeTransferAgent has all required methods."""
        required_methods = [
            "register_memory",
            "deregister_memory",
            "load_remote_agent",
            "load_remote_agent_by_connection",
            "get_local_agent_desc",
            "get_local_connection_info",
            "invalidate_remote_agent",
            "submit_transfer_requests",
            "notify_sync_message",
            "get_notified_sync_messages",
            "check_remote_descs",
        ]
        for method in required_methods:
            assert hasattr(tab.MooncakeTransferAgent, method), f"Missing method: {method}"


# =============================================================================
# Functional Tests - Data Transfer Validation
# =============================================================================


def _create_memory_descs_from_tensor(tensor, memory_type):
    """Helper to create MemoryDescs from a torch tensor."""
    addr = tensor.data_ptr()
    size = tensor.numel() * tensor.element_size()
    device_id = tensor.device.index if tensor.is_cuda else 0
    desc = tab.MemoryDesc(addr, size, device_id)
    return tab.MemoryDescs(memory_type, [desc])
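

# The functional tests below all follow the same two-agent pattern:
#   1. register local memory regions with each agent,
#   2. exchange AgentDesc blobs and call load_remote_agent on both sides,
#   3. build a TransferRequest naming the remote agent,
#   4. submit_transfer_requests and wait on the returned status,
#   5. verify the destination tensors and deregister the memory.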


@pytest.mark.skipif(
    not (HAS_TORCH and HAS_CUDA),
    reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlFunctionalTransfer:
    """Functional tests for NIXL data transfer between two agents."""

    def test_nixl_write_transfer_gpu_tensor(self):
        """Test WRITE transfer of GPU tensor data between two NIXL agents."""
        device = torch.device("cuda:0")

        # Create source tensor with known data pattern
        src_tensor = torch.arange(1024, dtype=torch.float32, device=device)

        # Create destination tensor (zeros)
        dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)

        # Verify initial state
        assert not torch.equal(src_tensor, dst_tensor)

        # Create two agents
        config_a = tab.BaseAgentConfig(
            name="agent_a",
            use_prog_thread=True,
            use_listen_thread=False,
        )
        config_b = tab.BaseAgentConfig(
            name="agent_b",
            use_prog_thread=True,
            use_listen_thread=False,
        )

        agent_a = tab.NixlTransferAgent(config_a)
        agent_b = tab.NixlTransferAgent(config_b)

        # Register memory regions
        src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
        dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)

        agent_a.register_memory(src_descs)
        agent_b.register_memory(dst_descs)

        # Exchange agent descriptors
        agent_a_desc = agent_a.get_local_agent_desc()
        agent_b_desc = agent_b.get_local_agent_desc()

        agent_a.load_remote_agent("agent_b", agent_b_desc)
        agent_b.load_remote_agent("agent_a", agent_a_desc)

        # Create transfer request: agent_a writes src_tensor to agent_b's dst_tensor
        request = tab.TransferRequest(
            tab.TransferOp.WRITE,
            src_descs,  # local source
            dst_descs,  # remote destination
            "agent_b",  # remote agent name
        )

        # Submit transfer and wait for completion
        status = agent_a.submit_transfer_requests(request)
        result = status.wait(timeout_ms=5000)

        assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"

        # Synchronize CUDA to ensure transfer is complete
        torch.cuda.synchronize()

        # Verify data was transferred correctly
        assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"

        # Cleanup
        agent_a.deregister_memory(src_descs)
        agent_b.deregister_memory(dst_descs)

    def test_nixl_write_transfer_multiple_chunks(self):
        """Test WRITE transfer with multiple memory chunks."""
        device = torch.device("cuda:0")

        # Create multiple source tensors
        src_tensors = [
            torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
            for i in range(4)
        ]

        # Create corresponding destination tensors
        dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]

        # Create agents
        config_a = tab.BaseAgentConfig(
            name="agent_a", use_prog_thread=True, use_listen_thread=False
        )
        config_b = tab.BaseAgentConfig(
            name="agent_b", use_prog_thread=True, use_listen_thread=False
        )

        agent_a = tab.NixlTransferAgent(config_a)
        agent_b = tab.NixlTransferAgent(config_b)

        # Create memory descriptors for all chunks
        src_memory_descs = []
        dst_memory_descs = []
        for src, dst in zip(src_tensors, dst_tensors):
            src_memory_descs.append(
                tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
            )
            dst_memory_descs.append(
                tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
            )

        src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
        dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)

        # Register memory
        agent_a.register_memory(src_descs)
        agent_b.register_memory(dst_descs)

        # Exchange agent info
        agent_a.load_remote_agent("agent_b", agent_b.get_local_agent_desc())
        agent_b.load_remote_agent("agent_a", agent_a.get_local_agent_desc())

        # Transfer
        request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, "agent_b")
        status = agent_a.submit_transfer_requests(request)
        result = status.wait(timeout_ms=5000)

        assert result == tab.TransferState.SUCCESS

        torch.cuda.synchronize()

        # Verify all chunks
        for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
            assert torch.equal(src, dst), f"Data mismatch in chunk {i}"

        # Cleanup
        agent_a.deregister_memory(src_descs)
        agent_b.deregister_memory(dst_descs)


@pytest.mark.skipif(
    not (HAS_TORCH and HAS_CUDA),
    reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeFunctionalTransfer:
    """Functional tests for Mooncake data transfer between two agents."""

    def test_mooncake_write_transfer_gpu_tensor(self):
        """Test WRITE transfer of GPU tensor data between two Mooncake agents."""
        device = torch.device("cuda:0")

        # Create source tensor with known data pattern
        src_tensor = torch.arange(1024, dtype=torch.float32, device=device)

        # Create destination tensor (zeros)
        dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)

        # Verify initial state
        assert not torch.equal(src_tensor, dst_tensor)

        # Create two agents
        config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
        config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)

        agent_a = tab.MooncakeTransferAgent(config_a)
        agent_b = tab.MooncakeTransferAgent(config_b)

        # Register memory regions
        src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
        dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)

        agent_a.register_memory(src_descs)
        agent_b.register_memory(dst_descs)

        # Exchange agent descriptors
        agent_a_desc = agent_a.get_local_agent_desc()
        agent_b_desc = agent_b.get_local_agent_desc()

        agent_a.load_remote_agent("mooncake_agent_b", agent_b_desc)
        agent_b.load_remote_agent("mooncake_agent_a", agent_a_desc)

        # Create transfer request: agent_a writes src_tensor to agent_b's dst_tensor
        request = tab.TransferRequest(
            tab.TransferOp.WRITE,
            src_descs,  # local source
            dst_descs,  # remote destination
            "mooncake_agent_b",  # remote agent name
        )

        # Submit transfer and wait for completion
        status = agent_a.submit_transfer_requests(request)
        result = status.wait()

        assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"

        # Synchronize CUDA to ensure transfer is complete
        torch.cuda.synchronize()

        # Verify data was transferred correctly
        assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"

        # Cleanup
        agent_a.deregister_memory(src_descs)
        agent_b.deregister_memory(dst_descs)

    def test_mooncake_write_transfer_multiple_chunks(self):
        """Test WRITE transfer with multiple memory chunks."""
        device = torch.device("cuda:0")

        # Create multiple source tensors
        src_tensors = [
            torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
            for i in range(4)
        ]

        # Create corresponding destination tensors
        dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]

        # Create agents
        config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
        config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)

        agent_a = tab.MooncakeTransferAgent(config_a)
        agent_b = tab.MooncakeTransferAgent(config_b)

        # Create memory descriptors for all chunks
        src_memory_descs = []
        dst_memory_descs = []
        for src, dst in zip(src_tensors, dst_tensors):
            src_memory_descs.append(
                tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
            )
            dst_memory_descs.append(
                tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
            )

        src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
        dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)

        # Register memory
        agent_a.register_memory(src_descs)
        agent_b.register_memory(dst_descs)

        # Exchange agent info
        agent_a.load_remote_agent("mooncake_agent_b", agent_b.get_local_agent_desc())
        agent_b.load_remote_agent("mooncake_agent_a", agent_a.get_local_agent_desc())

        # Transfer
        request = tab.TransferRequest(
            tab.TransferOp.WRITE, src_descs, dst_descs, "mooncake_agent_b"
        )
        status = agent_a.submit_transfer_requests(request)
        result = status.wait(timeout_ms=5000)

        assert result == tab.TransferState.SUCCESS

        torch.cuda.synchronize()

        # Verify all chunks
        for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
            assert torch.equal(src, dst), f"Data mismatch in chunk {i}"

        # Cleanup
        agent_a.deregister_memory(src_descs)
        agent_b.deregister_memory(dst_descs)
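

# Optional convenience entry point (an assumption, not part of the original
# test flow): running this file directly behaves like `pytest <this file>`.
if __name__ == "__main__":
    pytest.main([__file__])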