This commit is contained in:
Shi Xiaowei 2026-01-13 19:20:38 +08:00 committed by GitHub
commit 883f9b6694
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1094 additions and 0 deletions

View File

@ -0,0 +1,128 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple, Union
from tensorrt_llm import logger
# We deliberately use a non-enum data structure here. This choice ensures that
# members are directly equivalent to the plain strings.
class TransferOp:
    """Transfer direction constants.

    Deliberately a plain class namespace (not an Enum) so members compare
    equal to plain strings, e.g. ``TransferOp.READ == "READ"``.
    """

    READ = "READ"
    WRITE = "WRITE"
class MemoryType:
    """Memory-space constants; plain strings for the same reason as TransferOp."""

    DRAM = "DRAM"  # host (CPU) memory
    VRAM = "VRAM"  # device (GPU) memory
    BLK = "BLK"  # block storage
    OBJ = "OBJ"  # object storage
    FILE = "FILE"  # file-backed storage
@dataclass
class MemoryDesc:
    """Describes one memory region taking part in a transfer."""

    ptr: int  # raw start address of the region
    size: int  # region length in bytes
    device_id: int  # device ordinal the region lives on
@dataclass
class MemoryDescs:
    """A batch of memory regions sharing one memory type."""

    type: str  # one of the MemoryType constants
    # Each entry is either a (ptr, size, device_id) tuple or a MemoryDesc.
    descs: List[Union[Tuple[int, int, int], MemoryDesc]]
@dataclass
class TransferRequest:
    """One transfer operation between local and remote memory regions."""

    op: TransferOp  # READ or WRITE, from the local agent's perspective
    src_descs: MemoryDescs  # local source regions
    dst_descs: MemoryDescs  # remote destination regions
    remote_name: str  # name of the peer agent involved
    sync_message: str  # message delivered to the peer on completion
class TransferStatus(ABC):
    """Handle for tracking an in-flight transfer."""

    @abstractmethod
    def is_completed(self) -> bool:
        """Return True once the transfer has finished (non-blocking)."""
        ...

    @abstractmethod
    def wait(self, timeout: float | None = None) -> None:
        """Block until the transfer finishes or *timeout* seconds elapse."""
        ...
class BaseTransferAgent(ABC):
    """Abstract interface for agents that move data between local and remote memory."""

    @abstractmethod
    def register_memory(self, descs: MemoryDescs) -> None:
        """Register local memory regions so they can participate in transfers."""
        ...

    @abstractmethod
    def deregister_memory(self, descs: MemoryDescs) -> None:
        """Undo a previous register_memory call for the given regions."""
        ...

    @abstractmethod
    def load_remote_agent(self, name: str, agent_desc: str) -> None:
        """Make a remote agent known under *name* from its serialized descriptor."""
        ...

    @abstractmethod
    def get_local_agent_desc(self) -> str:
        """Return this agent's serialized descriptor for exchange with peers."""
        ...

    @abstractmethod
    def invalidate_remote_agent(self, name: str) -> None:
        """Forget a previously loaded remote agent."""
        ...

    @abstractmethod
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Start the described transfer and return a status handle tracking it."""
        ...

    @abstractmethod
    def notify_sync_message(self, name: str, sync_message: str) -> None:
        """Send an out-of-band synchronization message to the named remote agent."""
        ...

    @abstractmethod
    def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool:
        """Return True if the named remote agent has the given descriptors available."""
        ...
# RegMemoryDescs is Python-only (used for registration with name field)
@dataclass
class RegMemoryDescs:
    """Registration-time variant of MemoryDescs carrying a per-region name."""

    type: str  # one of the MemoryType constants
    descs: List[Tuple[int, int, int, str]]  # (ptr, size, device_id, name)
def _force_py_nixl_kv_transfer() -> bool:
res = os.getenv("TRTLLM_USE_PY_NIXL_KVCACHE", "0") == "1"
if res:
logger.info("Forcing use of pure Python NIXL KV Transfer Agent implementation.")
return res
def _try_load_cpp_binding():
    """Import the C++ (pybind) transfer-agent binding if present and complete.

    Returns the binding module only when every symbol this file re-exports is
    available on it; returns None otherwise so callers fall back to the pure
    Python definitions declared above.
    """
    required = (
        "MemoryType",
        "TransferOp",
        "MemoryDesc",
        "MemoryDescs",
        "TransferRequest",
        "TransferStatus",
        "BaseTransferAgent",
    )
    try:
        import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _cpp_binding
    except ImportError:
        logger.info("tensorrt_llm_transfer_agent_binding module not found.")
        return None
    if all(hasattr(_cpp_binding, name) for name in required):
        return _cpp_binding
    return None
# Resolve the public symbols: prefer the C++ (pybind) binding when it is
# importable and complete, unless TRTLLM_USE_PY_NIXL_KVCACHE=1 forces the
# pure-Python path. On fallback, the Python definitions above stay in effect,
# so importers see one consistent set of names either way.
_cpp_binding = _try_load_cpp_binding()
if _cpp_binding and not _force_py_nixl_kv_transfer():
    # Rebind the module-level names to the C++ implementations.
    MemoryType = _cpp_binding.MemoryType
    TransferOp = _cpp_binding.TransferOp
    MemoryDesc = _cpp_binding.MemoryDesc
    MemoryDescs = _cpp_binding.MemoryDescs
    TransferRequest = _cpp_binding.TransferRequest
    TransferStatus = _cpp_binding.TransferStatus
    BaseTransferAgent = _cpp_binding.BaseTransferAgent
    logger.info("Using Pybind transfer agent binding for Transfer Agent implementation.")
else:
    logger.warning(
        "Failed to import Pybind transfer agent binding, using pure Python implementation."
    )

View File

@ -0,0 +1,144 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from tensorrt_llm import DisaggregatedParams
@dataclass
class KVSlice:
    """Supports transmitting only part of the request cache, e.g., chunks or layers."""

    start_token_idx: Optional[int] = None  # first token covered (None = from start)
    end_token_idx: Optional[int] = None  # one-past-last token (None = to end)
    start_layer: Optional[int] = None  # first layer covered (None = all layers)
    end_layer: Optional[int] = None  # one-past-last layer
    blocks: List[int] = field(default_factory=list)  # cache block ids in this slice
    is_last_slice: bool = False  # True on the final slice of a request
class SessionStatus(Enum):
    """Status of a transfer session."""

    INIT = "INIT"
    READY = "READY"
    TRANSFERRING = "TRANSFERRING"
    TRANSFERRED = "TRANSFERRED"
    # KV data done; auxiliary data (e.g. metadata) also transferred.
    AUX_TRANSFERRED = "AUX_TRANSFERRED"
    COMPLETED = "COMPLETED"
    CANCELED = "CANCELED"
    ERROR = "ERROR"
# Identifier handed out by send()/receive() and accepted by poll_task().
TaskIdType = int


@dataclass
class SessionState:
    """State of a transfer session."""

    status: SessionStatus  # current lifecycle phase
    finished_tasks: List[TaskIdType]  # ids of tasks completed so far
@dataclass
class SessionArgsBase:
    """Base arguments for transfer sessions."""

    # Disaggregated-serving parameters identifying the request being moved.
    params: DisaggregatedParams
class SenderBase(ABC):
    """Base class for sending KV cache data; concrete API defined by subclasses."""

    ...
class ReceiverBase(ABC):
    """Base class for receiving KV cache data; concrete API defined by subclasses."""

    ...
class TxSessionBase(ABC):
    """Abstract sender-side (transmission) session of a KV-cache transfer."""

    def __init__(self, sender: SenderBase, args: SessionArgsBase):
        """
        Initializes the transmission session.

        :param sender: The sender instance responsible for sending data.
        :param args: The session arguments.
        """
        # NOTE(review): `sender` is accepted but not stored here; concrete
        # subclasses presumably keep their own reference — confirm.
        self._base_args = args

    @property
    @abstractmethod
    def state(self) -> SessionState:
        """
        Returns the current state of the session.
        """
        ...

    @abstractmethod
    def poll_task(self, id: TaskIdType) -> SessionStatus:
        """
        Polls the status of a specific task by its ID.

        :param id: The task ID to poll.
        """
        # NOTE(review): `id` shadows the builtin; renaming would break
        # keyword callers, so it is kept as-is.
        ...

    @abstractmethod
    def send(self, slice: KVSlice) -> TaskIdType:
        """
        Sends a slice of KV cache data and returns the task ID.

        :param slice: The KV slice to send.
        """
        ...

    @property
    @abstractmethod
    def exception(self) -> Optional[Exception]:
        """
        Returns any exception that occurred during the session.
        """
        ...
class RxSessionBase(ABC):
    """Abstract receiver-side (reception) session of a KV-cache transfer."""

    def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
        """
        Initializes the reception session.

        :param receiver: The receiver instance responsible for receiving data.
        :param args: The session arguments.
        """
        # NOTE(review): `receiver` is accepted but not stored here; concrete
        # subclasses presumably keep their own reference — confirm.
        self._base_args = args

    @property
    @abstractmethod
    def state(self) -> SessionState:
        """
        Returns the current state of the session.
        """
        ...

    @abstractmethod
    def poll_task(self, id: TaskIdType) -> SessionStatus:
        """
        Polls the status of a specific task by its ID.

        :param id: The task ID to poll.
        """
        ...

    @abstractmethod
    def receive(self, slice: KVSlice) -> TaskIdType:
        """
        Receives a slice of KV cache data and returns the task ID.

        :param slice: The KV slice to receive.
        """
        ...

    @property
    @abstractmethod
    def exception(self) -> Optional[Exception]:
        """Returns any exception that occurred during the session."""
        ...

View File

@ -0,0 +1,228 @@
from abc import ABC, abstractmethod
from threading import Event, Lock, Thread
from typing import Callable, Optional
import zmq
from tensorrt_llm import logger
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
class MessengerInterface(ABC):
    """
    Abstract base class for messenger implementations.

    Defines a minimal multipart-message transport: lifecycle (start/stop),
    blocking send/receive of byte-frame lists, and an optional background
    listener that dispatches incoming messages to a callback.
    """

    @abstractmethod
    def start(self):
        """
        Start the messenger service.
        """
        ...

    @abstractmethod
    def send(self, messages: list[bytes], recipient: Optional[bytes] = None):
        """
        Send messages to a recipient.

        :param messages: List of byte messages to send.
        :param recipient: Optional recipient identifier (e.g. a ROUTER identity frame).
        """
        ...

    @abstractmethod
    def send_encoded(self, *messages, encoding: str = "ascii"):
        """
        Send messages after encoding them.

        :param messages: Messages to send.
        :param encoding: Encoding format.
        """
        ...

    @abstractmethod
    def receive(self) -> list[bytes]:
        """
        Receive messages.

        :return: List of byte messages received.
        """
        ...

    @abstractmethod
    def start_listener(self, on_message: Callable[[list[bytes]], None]):
        """
        Start a listener thread to handle incoming messages.

        :param on_message: Callback function to process received messages.
        """
        ...

    @abstractmethod
    def stop(self):
        """
        Stop the messenger service.
        """
        ...

    @property
    @abstractmethod
    def endpoint(self) -> str:
        """
        Get the endpoint of the messenger.

        :return: Endpoint string.
        """
        ...
def decode_message(
    message: list[bytes], encoding: str = "ascii", err_mode: str = "strict"
) -> tuple:
    """Decode every frame of a multipart message into a tuple of strings.

    :param message: List of byte frames.
    :param encoding: Text encoding applied to each frame.
    :param err_mode: Error handling passed to ``bytes.decode``.
    :raises ValueError: If *message* is not a list of ``bytes`` objects.
    """
    is_frame_list = isinstance(message, list) and all(
        isinstance(frame, bytes) for frame in message
    )
    if not is_frame_list:
        raise ValueError("Input must be a list of bytes")
    return tuple(frame.decode(encoding, errors=err_mode) for frame in message)
class ZMQMessenger(MessengerInterface):
    """ZeroMQ-backed messenger.

    Bind-side modes (ROUTER, REP) may run a background listener thread; an
    inproc PAIR socket pair is used to wake that thread's poller when
    stop() is called.
    """

    SOCKET_MODES = {
        "ROUTER": zmq.ROUTER,  # Handles multiple connections and routes messages by address.
        "DEALER": zmq.DEALER,  # Load balances outgoing messages and receives replies fairly.
        "REQ": zmq.REQ,  # Sends requests and waits for replies (synchronous).
        "REP": zmq.REP,  # Receives requests and sends replies (synchronous).
    }

    def __init__(self, mode: str, endpoint: Optional[str] = None):
        """Create the socket and bind (ROUTER/REP) or connect (DEALER/REQ) it.

        :param mode: One of SOCKET_MODES.
        :param endpoint: Target endpoint; defaults to a wildcard TCP port on
            the local IP. Bind-side modes resolve the wildcard to the port
            actually assigned.
        """
        self._context = zmq.Context()
        self._mode = mode
        self._socket = self._context.socket(self.SOCKET_MODES[mode])
        self._endpoint: Optional[str] = None
        self._lock = Lock()
        self._closed = False
        self._stop_event = Event()
        self._listener_thread: Optional[Thread] = None
        self._initialize_control_sockets()
        if endpoint is None:
            # Wildcard port: the OS picks a free one on bind.
            endpoint = f"tcp://{get_local_ip()}:*"
        if mode in ["ROUTER", "REP"]:
            self._socket.bind(endpoint)
            # Resolve the wildcard to the concrete endpoint actually bound.
            self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
        elif mode in ["DEALER", "REQ"]:
            self._socket.connect(endpoint)
            self._endpoint = endpoint
        logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})")

    def _initialize_control_sockets(self):
        """Create the inproc PAIR sockets used to signal the listener thread."""
        self._control_socket = self._context.socket(zmq.PAIR)
        self._internal_socket = self._context.socket(zmq.PAIR)
        inproc_endpoint = "inproc://stop_listener"
        self._control_socket.bind(inproc_endpoint)
        self._internal_socket.connect(inproc_endpoint)

    def start(self):
        # The socket is fully set up in __init__; nothing more to do.
        pass

    def send(self, messages: list[bytes], recipient: Optional[bytes] = None):
        """Send a multipart message, prefixing the ROUTER identity frame if given."""
        if recipient:
            self._socket.send_multipart([recipient] + messages)
        else:
            self._socket.send_multipart(messages)

    def send_encoded(self, *messages, encoding: str = "ascii"):
        """Stringify and encode each argument, then send as one multipart message."""
        encoded_messages = [str(message).encode(encoding) for message in messages]
        self.send(encoded_messages)

    def receive(self) -> list[bytes]:
        """Block until a full multipart message arrives and return its frames."""
        return self._socket.recv_multipart()

    def start_listener(
        self,
        on_message: Callable[[list[bytes]], None],
        on_error: Optional[Callable[[Exception], None]] = None,
    ):
        """Start a daemon thread that polls the socket and dispatches messages.

        :param on_message: Callback for each multipart message; if it returns
            False the listener stops (a broader contract than the ABC's
            ``-> None`` annotation).
        :param on_error: Optional callback invoked with any exception raised
            while polling or handling a message.
        :raises RuntimeError: If a listener is already running.
        """
        assert self._mode in ["ROUTER", "REP"], (
            "Listener can only be started in ROUTER or REP modes"
        )
        if self._listener_thread and self._listener_thread.is_alive():
            raise RuntimeError("Listener already running")

        def listener():
            poller = zmq.Poller()
            poller.register(self._socket, zmq.POLLIN)
            poller.register(self._control_socket, zmq.POLLIN)
            while not self._stop_event.is_set():
                try:
                    # FIX: poll() now sits inside the try so a ZMQError raised
                    # while polling (e.g. socket closed during shutdown) is
                    # reported through on_error instead of killing the thread
                    # with an unhandled exception.
                    events = dict(poller.poll(timeout=100))
                    if self._control_socket in events:
                        # stop() sent a wake-up message.
                        self._stop_event.set()
                    elif self._socket in events:
                        messages = self.receive()
                        try:
                            persist = on_message(messages)
                            if persist is False:
                                self._stop_event.set()
                        except Exception as e:
                            logger.error(f"Error in on_message callback: {e}")
                            if on_error:
                                on_error(e)
                            else:
                                self._stop_event.set()
                except zmq.ZMQError as e:
                    logger.error(f"ZMQ Error in listener: {e}")
                    if on_error:
                        on_error(e)
                    break
                except Exception as e:
                    logger.error(f"Error in listener: {e}")
                    if on_error:
                        on_error(e)
                    break
            self._stop_event.set()

        self._listener_thread = Thread(target=listener, daemon=True)
        self._listener_thread.start()

    def stop(self, timeout=5):
        """Stop the listener (if any), close all sockets, terminate the context.

        Idempotent: subsequent calls return immediately.

        :param timeout: Seconds to wait for the listener thread to exit.
        """

        def _close_socket(socket: zmq.Socket):
            try:
                if not socket.closed:
                    socket.close()
            except Exception as e:
                logger.error(f"Error closing socket: {e}")

        with self._lock:
            if self._closed:
                return
            self._closed = True
            logger.debug("Stopping ZMQMessenger...")
            self._stop_event.set()
            # Wake the listener's poller so it observes the stop event.
            self._internal_socket.send(b"STOP")
            if self._listener_thread:
                self._listener_thread.join(timeout)
                if self._listener_thread.is_alive():
                    logger.warning("Listener thread did not terminate within timeout")
            _close_socket(self._socket)
            _close_socket(self._internal_socket)
            _close_socket(self._control_socket)
            try:
                # BUGFIX: the condition was inverted (`if self._context.closed:
                # self._context.term()`), so the context was never terminated
                # and its resources leaked.
                if not self._context.closed:
                    self._context.term()
            except Exception as e:
                logger.error(f"Error terminating ZMQ context: {e}")

    @property
    def endpoint(self) -> str:
        """The concrete endpoint this messenger is bound/connected to."""
        assert self._endpoint is not None
        return self._endpoint

    def __del__(self):
        # Best-effort cleanup: attributes may already be torn down during
        # interpreter shutdown, so never let __del__ raise.
        try:
            self.stop()
        except Exception:
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

View File

@ -0,0 +1,34 @@
def get_local_ip() -> str:
    """Best-effort detection of a non-loopback local IPv4 address.

    Tries three strategies in order, falling back to ``"127.0.0.1"``:
      1. A UDP "connect" to a public address and reading the source address
         the OS selected (no packet is actually sent).
      2. Enumerating interfaces via the optional ``netifaces`` package.
      3. Resolving the local hostname.
    """
    # FIX: hoisted out of the first try-block — strategy 3 also uses `socket`,
    # so its availability should not be tied to strategy 1's body.
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            # connect() on a UDP socket only selects a route; nothing is sent.
            s.connect(("8.8.8.8", 80))
            ip = s.getsockname()[0]
            if not ip.startswith("127."):
                return ip
    except OSError:
        pass
    try:
        import netifaces  # optional dependency; skip this strategy if missing

        for iface in netifaces.interfaces():
            addrs = netifaces.ifaddresses(iface)
            if netifaces.AF_INET in addrs:
                for addr in addrs[netifaces.AF_INET]:
                    ip = addr.get("addr", "")
                    # Skip loopback and link-local (169.254.x.x) addresses.
                    if not ip.startswith("127.") and not ip.startswith("169.254"):
                        return ip
    except Exception:
        pass
    try:
        hostname = socket.gethostname()
        ip = socket.gethostbyname(hostname)
        if not ip.startswith("127."):
            return ip
    except OSError:
        pass
    return "127.0.0.1"

View File

@ -0,0 +1,132 @@
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( # noqa: E402
AgentDesc,
BaseAgentConfig,
MemoryDescs,
MemoryType,
TransferState,
)
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
NixlTransferAgent as CppNixlTransferAgent,
)
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
class BindingsNixlTransferStatus(TransferStatus):
    """Wraps the C++ binding's status object; blocking calls release the GIL."""

    def __init__(self, cpp_status):
        self._cpp_status = cpp_status

    def is_completed(self) -> bool:
        """Non-blocking check: True once the underlying transfer has finished."""
        return self._cpp_status.is_completed()

    @nvtx_range("BindingsNixlTransferStatus.wait")
    def wait(self) -> bool:
        """Block until the transfer ends; True iff the final state is SUCCESS."""
        final_state = self._cpp_status.wait()
        return final_state == TransferState.SUCCESS
class BindingsNixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using C++ bindings with GIL release support.

    This implementation uses the standalone nixl_bindings C++ module which releases
    the GIL during blocking operations like wait().

    The nixl_bindings module is independent from the main trtllm bindings,
    so trtllm can function normally even without NIXL.
    """

    def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1):
        """
        :param name: Unique name of this agent.
        :param use_prog_thread: Enable the NIXL progress thread.
        :param num_workers: Worker-thread count passed to the C++ agent config.
        """
        config = BaseAgentConfig(
            name,
            use_prog_thread,
            multi_thread=False,
            use_listen_thread=False,
            num_workers=num_workers,
        )
        self._cpp_agent = CppNixlTransferAgent(config)
        self.name = name

    # NOTE(review): register/deregister take RegMemoryDescs here, while the
    # BaseTransferAgent ABC declares MemoryDescs — callers must pass the
    # registration variant (with the trailing name field).
    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.register_memory(cpp_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.deregister_memory(cpp_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Load a remote agent by its descriptor (bytes)."""
        # AgentDesc expects std::string which can hold binary data
        desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode()
        cpp_desc = AgentDesc(desc_str)
        self._cpp_agent.load_remote_agent(name, cpp_desc)

    def load_remote_agent_by_connection(self, name: str, connection_info: str):
        """Load a remote agent by connection info."""
        self._cpp_agent.load_remote_agent_by_connection(name, connection_info)

    def get_local_agent_desc(self) -> bytes:
        """Get the local agent descriptor as bytes."""
        agent_desc = self._cpp_agent.get_local_agent_desc()
        return agent_desc.backend_agent_desc  # Returns bytes

    def get_local_connection_info(self) -> str:
        """Get the local connection info."""
        return self._cpp_agent.get_local_connection_info()

    def invalidate_remote_agent(self, name: str):
        """Invalidate a remote agent."""
        self._cpp_agent.invalidate_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool:
        """Check if remote descriptors are available.

        memory_descs should be C++ MemoryDescs type.
        """
        return self._cpp_agent.check_remote_descs(name, memory_descs)

    def notify_sync_message(self, name: str, sync_message: str):
        """Send a sync message to a remote agent."""
        self._cpp_agent.notify_sync_message(name, sync_message)

    def get_notified_sync_messages(self):
        """Get notified sync messages."""
        return self._cpp_agent.get_notified_sync_messages()

    @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Submit transfer requests and return status.

        request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding).
        """
        cpp_status = self._cpp_agent.submit_transfer_requests(request)
        return BindingsNixlTransferStatus(cpp_status)

    def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs":
        """Convert Python RegMemoryDescs to C++ MemoryDescs.

        RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name)
        Extract first 3 elements for C++ batch constructor.
        """
        mem_type = self._convert_memory_type(descs.type)
        # Extract (ptr, size, device_id) from 4-tuple, discard name
        tuples = [(d[0], d[1], d[2]) for d in descs.descs]
        return MemoryDescs(mem_type, tuples)

    def _convert_memory_type(self, py_type: str) -> "MemoryType":
        """Convert Python memory type string to C++ MemoryType."""
        type_map = {
            "DRAM": MemoryType.DRAM,
            "VRAM": MemoryType.VRAM,
            "GPU": MemoryType.VRAM,
            "BLK": MemoryType.BLK,
            "OBJ": MemoryType.OBJ,
            "FILE": MemoryType.FILE,
        }
        # NOTE(review): unknown type strings silently map to VRAM — confirm
        # this fallback is intended rather than raising.
        return type_map.get(py_type.upper(), MemoryType.VRAM)

View File

@ -0,0 +1,99 @@
import time
from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle # noqa: E402
from tensorrt_llm._utils import nvtx_range
# Import base classes for type compatibility
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
class NixlTransferStatus(TransferStatus):
    """TransferStatus backed by the Python ``nixl`` library."""

    def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle):
        self.agent = agent
        self.handle = handle

    def is_completed(self):
        """Return True once the transfer has reached the DONE state."""
        return self.agent.check_xfer_state(self.handle) == "DONE"

    def wait(self):
        """Poll until the transfer leaves PROC; True on success, False on error.

        Uses exponential backoff (0.1 ms doubling up to 10 ms) and sleeps
        between polls so the GIL is released for other threads.
        """
        backoff = 0.0001  # start at 0.1 ms
        backoff_cap = 0.01  # never sleep longer than 10 ms
        state = "PROC"
        while state == "PROC":
            state = self.agent.check_xfer_state(self.handle)
            if state == "ERROR":
                return False  # transfer failed
            time.sleep(backoff)  # yields the GIL between polls
            backoff = min(backoff * 2, backoff_cap)
        return True
class NixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using Python nixl library."""

    def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1):
        """
        Initialize NixlTransferAgent.

        :param name: Name of the agent.
        :param use_prog_thread: Whether to enable the progress thread, if available.
        :param num_workers: Specify number of threads for the supported multi-threaded backends.
        """
        self.name = name
        self.backends = ["UCX"]
        agent_config = nixl_agent_config(
            enable_prog_thread=use_prog_thread, backends=self.backends, num_threads=num_workers
        )
        self.agent = nixl_agent(name, agent_config)

    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions described by (ptr, size, device_id, name) tuples."""
        if isinstance(descs.descs[0], tuple):
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        self.agent.register_memory(reg_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister previously registered memory regions."""
        if isinstance(descs.descs[0], tuple):
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        self.agent.deregister_memory(reg_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Load a remote agent from its serialized metadata."""
        # NOTE(review): `name` is unused here — nixl derives the peer name
        # from the metadata itself; confirm callers pass a matching name.
        self.agent.add_remote_agent(agent_desc)

    def get_local_agent_desc(self):
        """Return this agent's serialized metadata for exchange with peers."""
        return self.agent.get_agent_metadata()

    def invalidate_remote_agent(self, name: str):
        """Forget the named remote agent."""
        self.agent.remove_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool:
        # Not supported by the pure-Python implementation.
        raise NotImplementedError

    def notify_sync_message(self, name: str, sync_message: str):
        # Not supported by the pure-Python implementation.
        raise NotImplementedError

    @nvtx_range("NixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Initialize and start the described transfer; returns a pollable status."""
        src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type)
        dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type)
        assert src_xfer_descs is not None, "Failed to get src_xfer_descs"
        assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs"
        # nixl requires a string sync message; map None to "".
        sync_message = "" if request.sync_message is None else request.sync_message
        handle = self.agent.initialize_xfer(
            request.op,
            src_xfer_descs,
            dst_xfer_descs,
            request.remote_name,
            sync_message,
        )
        status = self.agent.transfer(handle)
        assert status != "ERROR"
        return NixlTransferStatus(self.agent, handle)

View File

@ -0,0 +1,46 @@
from tensorrt_llm.logger import logger
from ..base.agent import _force_py_nixl_kv_transfer
"""NIXL Transfer Agent implementations.
This module provides two implementations:
1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support
2. NixlTransferAgent - Uses the Python nixl library directly (fallback)
The standalone nixl_bindings module is separate from the main trtllm bindings,
so trtllm can still function normally even without NIXL dependencies.
"""
def _load_agent(module_name, required_attributes):
try:
module = __import__(module_name, fromlist=required_attributes, level=0)
if all(hasattr(module, attr) for attr in required_attributes):
return module
except ImportError:
logger.info("Failed to import module: %s", module_name)
return None
# Probe both implementations; each loader returns None on failure.
_py_agent = _load_agent(
    module_name="tensorrt_llm._torch.disaggregation.nixl._agent_py",
    required_attributes=["NixlTransferAgent", "NixlTransferStatus"],
)
_cpp_agent = _load_agent(
    module_name="tensorrt_llm._torch.disaggregation.nixl._agent_cpp",
    required_attributes=["BindingsNixlTransferAgent", "BindingsNixlTransferStatus"],
)
# Determine which Transfer Agent implementation to use
if _cpp_agent and not _force_py_nixl_kv_transfer():
    # Prefer the C++ path: it releases the GIL during blocking waits.
    NixlTransferStatus = _cpp_agent.BindingsNixlTransferStatus
    NixlTransferAgent = _cpp_agent.BindingsNixlTransferAgent
    logger.info("Using C++ NIXL Transfer Agent implementation.")
elif _py_agent:
    NixlTransferStatus = _py_agent.NixlTransferStatus
    NixlTransferAgent = _py_agent.NixlTransferAgent
    logger.info("Using Python NIXL Transfer Agent implementation.")
else:
    raise ImportError("Both C++ and Python NIXL Transfer Agents failed to load.")

View File

@ -0,0 +1,156 @@
from dataclasses import dataclass, field
import pytest
import torch
from tensorrt_llm import logger
from tensorrt_llm._torch.disaggregation.base.agent import (
MemoryDescs,
MemoryType,
RegMemoryDescs,
TransferOp,
TransferRequest,
)
from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent
def _convert_to_memory_descs(reg_descs: RegMemoryDescs) -> MemoryDescs:
    """Strip the trailing name field from registration descs and map the type."""

    def _to_memory_type(type_str: str) -> MemoryType:
        """Convert Python memory type string to C++ MemoryType."""
        mapping = {
            "DRAM": MemoryType.DRAM,
            "VRAM": MemoryType.VRAM,
            "GPU": MemoryType.VRAM,
            "BLK": MemoryType.BLK,
            "OBJ": MemoryType.OBJ,
            "FILE": MemoryType.FILE,
        }
        return mapping.get(type_str.upper(), MemoryType.VRAM)

    # Drop the name element of each (ptr, size, device_id, name) tuple.
    triples = [(ptr, size, device_id) for (ptr, size, device_id, _) in reg_descs.descs]
    return MemoryDescs(_to_memory_type(reg_descs.type), triples)
@dataclass
class MemoryManager:
    """Allocates torch buffers and keeps references so they stay alive during tests."""

    # Holds references so buffers are not garbage collected mid-transfer.
    allocated_memory: list[torch.Tensor] = field(default_factory=list)

    def allocate_memory(
        self, size: int, name: str, memory_type=MemoryType.VRAM, device_id: int = 0
    ) -> RegMemoryDescs:
        """Allocate a zeroed uint8 buffer and return its registration descriptor.

        :param size: Buffer length in bytes.
        :param name: Registration name recorded in the descriptor.
        :param memory_type: VRAM allocates on CUDA; anything else on CPU.
        :param device_id: CUDA device ordinal for VRAM allocations.
        """
        # NOTE(review): fixtures pass plain strings ("DRAM"/"VRAM"); the
        # comparison below only works while MemoryType members are plain
        # strings (pure-Python fallback) — with a C++ enum it compares False.
        device = torch.device(f"cuda:{device_id}" if memory_type == MemoryType.VRAM else "cpu")
        # Allocate memory block using torch.Tensor and track it
        block = torch.zeros(size, dtype=torch.uint8, device=device)
        self.allocated_memory.append(block)
        # Return RegMemoryDescs with position arguments
        memory_descs = RegMemoryDescs(
            type=memory_type, descs=[(block.data_ptr(), block.numel(), device_id, name)]
        )
        return memory_descs

    def clear_memory(self):
        """Clear all tracked memory blocks."""
        self.allocated_memory.clear()
@pytest.fixture
def memory_manager():
    """Fresh MemoryManager per test so allocations do not leak across tests."""
    return MemoryManager()
@pytest.fixture(params=[256, 512])
def memory_size(request):
    """Parameterized buffer size in bytes."""
    return request.param
@pytest.fixture(params=["DRAM", "VRAM"])
def memory_type(request):
return request.param
@pytest.fixture
def alloc(memory_manager, memory_size, memory_type):
    """Allocate memory for source and destination, based on the memory_size and memory_type parameters."""
    assert memory_size > 0, "Memory size must be a positive integer."
    src_descs = memory_manager.allocate_memory(
        size=memory_size, name="src_mem", memory_type=memory_type
    )
    dst_descs = memory_manager.allocate_memory(
        size=memory_size, name="dst_mem", memory_type=memory_type
    )
    # Returns (source descs, destination descs); buffers stay alive via memory_manager.
    return src_descs, dst_descs
@pytest.fixture
def transfer_agent_src():
    """Sender-side NIXL agent under its well-known test name."""
    return NixlTransferAgent(name="src_agent")
@pytest.fixture
def transfer_agent_dst():
    """Receiver-side NIXL agent under its well-known test name."""
    return NixlTransferAgent(name="dst_agent")
def test_transfer_between_agents(
    transfer_agent_src,
    transfer_agent_dst,
    memory_manager,
    alloc,
    memory_size,
    memory_type,
):
    """End-to-end test of data transfer between two agents with parameterized memory sizes and types.

    NOTE(review): requires a working NIXL/UCX install (and CUDA for the VRAM
    parameterization) — this is an integration test, not a unit test.
    """
    # Debug log the parameters being tested
    logger.info(f"Testing with memory_size={memory_size}, memory_type={memory_type}")
    # Unpack source and destination memory descriptions
    memory_descs_src, memory_descs_dst = alloc
    # Fill source memory with sequential data for validation
    src_data = memory_manager.allocated_memory[0]
    assert memory_size > 0, "Memory size must be positive."
    tensor = torch.arange(memory_size, dtype=torch.uint8) % 10
    src_data.copy_(tensor)
    # Register memory with source and destination agents
    transfer_agent_src.register_memory(memory_descs_src)
    transfer_agent_dst.register_memory(memory_descs_dst)
    # Exchange serialized agent metadata in both directions.
    src_agent_desc = transfer_agent_src.get_local_agent_desc()
    transfer_agent_dst.load_remote_agent("src_agent", src_agent_desc)
    dst_agent_desc = transfer_agent_dst.get_local_agent_desc()
    transfer_agent_src.load_remote_agent("dst_agent", dst_agent_desc)
    # Create and submit the transfer request
    transfer_request = TransferRequest(
        op=TransferOp.WRITE,
        src_descs=_convert_to_memory_descs(memory_descs_src),
        dst_descs=_convert_to_memory_descs(memory_descs_dst),
        remote_name="dst_agent",
        sync_message=None,
    )
    transfer_status = transfer_agent_src.submit_transfer_requests(transfer_request)
    transfer_status.wait()
    # Validate transfer completion
    assert transfer_status.is_completed(), "Transfer did not complete successfully."
    # Validate that the destination data matches the source data
    dst_data = memory_manager.allocated_memory[1]
    assert torch.equal(dst_data, src_data), "Destination data does not match source data."
    # Clean up by deregistering memory and clearing allocations
    transfer_agent_src.deregister_memory(memory_descs_src)
    transfer_agent_dst.deregister_memory(memory_descs_dst)
    memory_manager.clear_memory()
    transfer_agent_src.invalidate_remote_agent("dst_agent")
    transfer_agent_dst.invalidate_remote_agent("src_agent")
if __name__ == "__main__":
pytest.main()

View File

@ -0,0 +1,127 @@
import socket
import time
import unittest
import pytest
from parameterized import parameterized
from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
# Table-driven cases for decode_message: each entry gives the input frames,
# decode parameters, and either the expected tuple or the exception type.
TEST_CASES = [
    {
        "name": "valid_message",
        "message": [b"hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": ("hello", "world"),
        "raises": None,
    },
    {
        "name": "invalid_input",
        "message": ["hello", b"world"],  # first frame is str, not bytes
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": ValueError,
    },
    {
        "name": "decoding_error",
        "message": [b"\xff"],  # invalid UTF-8 byte
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": UnicodeDecodeError,
    },
    {
        "name": "decoding_with_ignore",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "ignore",
        "expected": ("",),  # invalid byte dropped, empty string remains
        "raises": None,
    },
]
class TestDecodeMessage(unittest.TestCase):
    """Table-driven tests for decode_message, fed from TEST_CASES."""

    @parameterized.expand([(case["name"], case) for case in TEST_CASES])
    def test_decode_message(self, name, case):
        """Run one TEST_CASES entry: expect either a tuple or a specific exception."""
        expected_exc = case["raises"]
        kwargs = {"encoding": case["encoding"], "err_mode": case["err_mode"]}
        if expected_exc:
            with self.assertRaises(expected_exc):
                decode_message(case["message"], **kwargs)
        else:
            decoded = decode_message(case["message"], **kwargs)
            self.assertEqual(decoded, case["expected"])
@pytest.fixture
def dynamic_endpoint():
    """Fixture to dynamically generate an available endpoint with a free port."""
    # NOTE(review): the port is released before the messenger binds it, so
    # another process could grab it in between (acceptable for tests).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # Bind to an available port provided by the OS
        port = s.getsockname()[1]
    return f"tcp://{get_local_ip()}:{port}"
@pytest.fixture
def create_messenger_pair(dynamic_endpoint):
    """Factory fixture building a (bind-side, connect-side) ZMQMessenger pair.

    The bind-side mode (ROUTER/REP) binds dynamic_endpoint; the connect-side
    mode (DEALER/REQ) connects to that same endpoint.
    """

    def _create_messenger_pair(mode1, mode2):
        messenger1 = ZMQMessenger(
            mode1, endpoint=dynamic_endpoint if mode1 in ["ROUTER", "REP"] else None
        )
        messenger2 = ZMQMessenger(
            mode2, endpoint=dynamic_endpoint if mode2 in ["DEALER", "REQ"] else None
        )
        return messenger1, messenger2

    # NOTE(review): created messengers are not stopped here; each test is
    # responsible for calling stop() on both.
    yield _create_messenger_pair
def test_router_dealer(create_messenger_pair):
    """Test ROUTER and DEALER communication."""
    router, dealer = create_messenger_pair("ROUTER", "DEALER")
    received_messages = []

    def on_message(messages):
        # ROUTER delivers [identity, payload]; both frames are collected here.
        received_messages.extend(messages)

    router.start_listener(on_message)
    dealer.send([b"Hello, ROUTER!"])
    time.sleep(0.1)  # NOTE(review): fixed sleep may flake on slow machines
    assert len(received_messages) > 0
    assert b"Hello, ROUTER!" in received_messages
    router.stop()
    dealer.stop()
def test_req_rep(create_messenger_pair):
    """Test REQ and REP communication."""
    rep, req = create_messenger_pair("REP", "REQ")

    def on_message(messages):
        # Echo the request back; REP requires a reply before the next recv.
        rep.send(messages)

    rep.start_listener(on_message)
    req.send([b"Hello, REP!"])
    # REQ's blocking receive doubles as the synchronization point here.
    response = req.receive()
    assert response == [b"Hello, REP!"]
    req.stop()
    rep.stop()
if __name__ == "__main__":
unittest.main()