[TRTLLM-9527][feat] Python transceiver components (step 2) (#10494)

Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
This commit is contained in:
Shi Xiaowei 2026-01-23 02:14:50 +08:00 committed by GitHub
parent 9adef4eb28
commit 944c304bbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 1331 additions and 30 deletions

View File

@ -296,7 +296,8 @@ struct BaseAgentConfig
bool useProgThread;
bool multiThread;
bool useListenThread;
unsigned int numWorkers;
bool enableTelemetry;
std::unordered_map<std::string, std::string> backendParams;
};
class BaseTransferAgent

View File

@ -247,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
mAgentName = genUniqueAgentName();
// Create Agent
BaseAgentConfig config{mAgentName, true, false, true, 1};
BaseAgentConfig config{mAgentName, true, false, true};
m_Agent = makeTransferAgent(backendType, &config);
TLLM_CHECK(!mCacheTransBufferManagers.empty());
std::vector<MemoryDesc> memDescs;

View File

@ -26,10 +26,9 @@
#endif
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
@ -69,6 +68,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
// MemoryDescs class
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
// Batch constructor from list of tuples: [(ptr, size, device_id), ...]
.def(
"__init__",
[](kvc::MemoryDescs* self, kvc::MemoryType type,
std::vector<std::tuple<uintptr_t, size_t, uint32_t>> const& tuples)
{
std::vector<kvc::MemoryDesc> descs;
descs.reserve(tuples.size());
for (auto const& [addr, len, deviceId] : tuples)
{
descs.emplace_back(addr, len, deviceId);
}
new (self) kvc::MemoryDescs(type, std::move(descs));
},
nb::arg("type"), nb::arg("tuples"))
.def_prop_ro("type", &kvc::MemoryDescs::getType)
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
@ -113,17 +127,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
.def(
"__init__",
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
bool use_listen_thread, unsigned int num_workers) {
new (self) kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
bool use_listen_thread, bool enable_telemetry,
std::unordered_map<std::string, std::string> backend_params)
{
new (self) kvc::BaseAgentConfig{std::move(name), use_prog_thread, multi_thread, use_listen_thread,
enable_telemetry, std::move(backend_params)};
},
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
nb::arg("use_listen_thread") = false, nb::arg("enable_telemetry") = false,
nb::arg("backend_params") = std::unordered_map<std::string, std::string>{})
.def_rw("name", &kvc::BaseAgentConfig::mName)
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
.def_rw("enable_telemetry", &kvc::BaseAgentConfig::enableTelemetry)
.def_rw("backend_params", &kvc::BaseAgentConfig::backendParams);
// BaseTransferAgent class (abstract base)
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")

View File

@ -66,6 +66,19 @@ PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
// MemoryDescs class
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
// Batch constructor from list of tuples: [(ptr, size, device_id), ...]
.def(py::init(
[](kvc::MemoryType type, std::vector<std::tuple<uintptr_t, size_t, uint32_t>> const& tuples)
{
std::vector<kvc::MemoryDesc> descs;
descs.reserve(tuples.size());
for (auto const& [addr, len, deviceId] : tuples)
{
descs.emplace_back(addr, len, deviceId);
}
return kvc::MemoryDescs(type, std::move(descs));
}),
py::arg("type"), py::arg("tuples"))
.def_property_readonly("type", &kvc::MemoryDescs::getType)
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
@ -108,17 +121,20 @@ PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
.def(py::init<>())
.def(py::init(
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
unsigned int num_workers) {
return kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
bool enable_telemetry, std::unordered_map<std::string, std::string> backend_params)
{
return kvc::BaseAgentConfig{std::move(name), use_prog_thread, multi_thread, use_listen_thread,
enable_telemetry, std::move(backend_params)};
}),
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
py::arg("use_listen_thread") = false, py::arg("enable_telemetry") = false,
py::arg("backend_params") = std::unordered_map<std::string, std::string>{})
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
.def_readwrite("enable_telemetry", &kvc::BaseAgentConfig::enableTelemetry)
.def_readwrite("backend_params", &kvc::BaseAgentConfig::backendParams);
// BaseTransferAgent class (abstract base)
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")

View File

@ -374,22 +374,28 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
}
auto envPort = common::getEnvNixlPort();
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
nixlAgentConfig nixlConfig{
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
uint32_t numWorker = config.backendParams.find("num_workers") != config.backendParams.end()
? std::stoi(config.backendParams.at("num_workers"))
: 1;
nixlAgentConfig nixlConfig{config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
numWorker, 0, 10000, config.enableTelemetry};
mAddress = getAvailableIP() + ":" + std::to_string(port);
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
else
{
uint32_t numWorker = config.backendParams.find("num_workers") != config.backendParams.end()
? std::stoi(config.backendParams.at("num_workers"))
: 1;
mAddress.clear();
nixlAgentConfig nixlConfig{
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
nixlAgentConfig nixlConfig{config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
numWorker, 0, 10000, config.enableTelemetry};
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
std::string nixlBackend = common::getEnvNixlBackend();
// List of supported backends - extend this list as new backends are added
static const std::set<std::string> kSUPPORTED_BACKENDS = {"UCX", "LIBFABRIC"};
static std::set<std::string> const kSUPPORTED_BACKENDS = {"UCX", "LIBFABRIC"};
if (kSUPPORTED_BACKENDS.find(nixlBackend) == kSUPPORTED_BACKENDS.end())
{
@ -400,6 +406,11 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent using NIXL backend: %s", nixlBackend.c_str());
nixl_b_params_t init1;
for (auto const& [key, value] : config.backendParams)
{
init1[key] = value;
TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent backendParams: %s: %s", key.c_str(), value.c_str());
}
nixl_mem_list_t mems1;
status = mRawAgent->getPluginParams(nixlBackend.c_str(), mems1, init1);
TLLM_CHECK(status == NIXL_SUCCESS);

View File

@ -92,7 +92,7 @@ TEST_P(TransferAgentTest, Basic)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -127,7 +127,7 @@ TEST_P(TransferAgentTest, Basic2)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -162,7 +162,7 @@ TEST_P(TransferAgentTest, DeviceMemory)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -209,8 +209,8 @@ TEST_P(TransferAgentTest, Connect)
{
std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1},
config2{agent2, true, false, true, 1};
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true},
config2{agent2, true, false, true};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
auto xferAgent2 = makeTransferAgent(config2);
@ -261,7 +261,7 @@ TEST_P(TransferAgentTest, SyncMessage)
{
constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<size_t>::max();
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);

View File

@ -38,3 +38,4 @@ opentelemetry-semantic-conventions-ai>=0.4.1
fuzzywuzzy==0.18.0
aiperf==0.3.0
nanobind>=2.9.0
nixl==0.8.0

View File

@ -0,0 +1,145 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
from tensorrt_llm import logger
# We deliberately use a non-enum data structure here. This choice ensures that
# members are directly equivalent to the plain strings.
class TransferOp:
    # Direction of a transfer relative to the local agent.
    READ = "READ"
    WRITE = "WRITE"
class MemoryType:
    # Plain-string memory-location tags (non-enum, so members compare equal
    # to the raw strings expected by transfer backends).
    DRAM = "DRAM"
    VRAM = "VRAM"
    BLK = "BLK"
    OBJ = "OBJ"
    FILE = "FILE"
class MemoryDesc(NamedTuple):
    """One memory region as a (ptr, size, device_id[, name]) tuple."""

    ptr: int  # base address of the region
    size: int  # length of the region in bytes
    device_id: int  # device the region belongs to
    name: Optional[str] = None  # optional backend-specific identifier
@dataclass
class MemoryDescs:
    """A collection of memory regions sharing one MemoryType string."""

    type: str  # one of the MemoryType values
    descs: List[MemoryDesc]
@dataclass
class TransferRequest:
    """One transfer between local ``src_descs`` and ``dst_descs`` on ``remote_name``."""

    op: TransferOp  # TransferOp.READ or TransferOp.WRITE
    src_descs: MemoryDescs
    dst_descs: MemoryDescs
    remote_name: str  # name of the remote agent involved in the transfer
    # Optional message associated with the transfer; presumably delivered to
    # the remote side on completion — confirm against the agent implementations.
    sync_message: Optional[str] = None
@dataclass
class RegMemoryDescs:
    """Memory regions to (de)register with an agent; mirrors MemoryDescs."""

    type: str  # one of the MemoryType values
    descs: List[MemoryDesc]
class TransferStatus(ABC):
    """Handle for polling or blocking on an in-flight transfer."""

    @abstractmethod
    def is_completed(self) -> bool: ...

    # Returns True if the transfer finished successfully within the timeout;
    # timeout_ms=None means wait indefinitely.
    @abstractmethod
    def wait(self, timeout_ms: int | None = None) -> bool: ...
class BaseTransferAgent(ABC):
    """Abstract transfer agent: memory registration, peer management, transfers."""

    @abstractmethod
    def register_memory(self, descs: RegMemoryDescs) -> None: ...

    @abstractmethod
    def deregister_memory(self, descs: RegMemoryDescs) -> None: ...

    # Load a remote peer from its serialized descriptor (see get_local_agent_desc).
    @abstractmethod
    def load_remote_agent(self, name: str, agent_desc: bytes) -> None: ...

    # Serialized descriptor of this agent, for exchange with remote peers.
    @abstractmethod
    def get_local_agent_desc(self) -> bytes: ...

    @abstractmethod
    def invalidate_remote_agent(self, name: str) -> None: ...

    # Submit a transfer and return a pollable status handle.
    @abstractmethod
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ...

    @abstractmethod
    def notify_sync_message(self, name: str, sync_message: str) -> None: ...

    # Whether the remote agent ``name`` has the given regions available.
    @abstractmethod
    def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: ...
def _force_py_nixl_kv_transfer() -> bool:
    """True when TRTLLM_USE_PY_NIXL_KVCACHE=1 forces the pure-Python agent.

    Any value other than '0' or '1' is logged and treated as '0'.
    """
    flag = os.getenv("TRTLLM_USE_PY_NIXL_KVCACHE", "0")
    if flag == "1":
        logger.info("Forcing use of pure Python NIXL KV Transfer Agent implementation.")
        return True
    if flag != "0":
        logger.warning(
            f"Invalid value for TRTLLM_USE_PY_NIXL_KVCACHE: {flag}. Expected '0' or '1'. Defaulting to '0'."
        )
    return False
def _try_load_cpp_binding():
    """Import the C++ transfer-agent binding; return the module, or None if it
    is missing or does not expose the full expected surface."""
    try:
        import tensorrt_llm.tensorrt_llm_transfer_agent_binding as binding
    except ImportError:
        logger.info("tensorrt_llm_transfer_agent_binding module not found.")
        return None
    needed = (
        "MemoryType",
        "TransferOp",
        "MemoryDesc",
        "MemoryDescs",
        "TransferRequest",
        "TransferStatus",
        "BaseTransferAgent",
    )
    for attr in needed:
        if not hasattr(binding, attr):
            # Partial/stale binding: fall back to the Python implementation.
            return None
    return binding
# Decided once at import time: whether the pure-Python agent is in use.
_use_pure_python_transfer_agent = None
# The current implementation still implicitly depends on cpp_bindings.
# We should remove this dependency.
_cpp_binding = _try_load_cpp_binding()
if _force_py_nixl_kv_transfer():
    logger.info("Using pure Python transfer agent (forced by TRTLLM_USE_PY_NIXL_KVCACHE)")
    _use_pure_python_transfer_agent = True
else:
    if _cpp_binding:
        # Re-bind the pure-Python placeholder types defined above to their
        # C++ counterparts so importers of this module get the fast path.
        MemoryType = _cpp_binding.MemoryType
        TransferOp = _cpp_binding.TransferOp
        MemoryDesc = _cpp_binding.MemoryDesc
        MemoryDescs = _cpp_binding.MemoryDescs
        TransferRequest = _cpp_binding.TransferRequest
        TransferStatus = _cpp_binding.TransferStatus
        BaseTransferAgent = _cpp_binding.BaseTransferAgent
        logger.info("Using C++ transfer agent binding")
        _use_pure_python_transfer_agent = False
    else:
        logger.info("C++ transfer agent binding unavailable, using pure Python implementation")
        _use_pure_python_transfer_agent = True


def use_pure_python_transfer_agent() -> bool:
    """Whether the pure-Python transfer agent implementation was selected at import."""
    return _use_pure_python_transfer_agent

View File

@ -0,0 +1,200 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from tensorrt_llm import DisaggregatedParams
@dataclass
class TokenRange:
    """Half-open range [start, end) of tokens in the sequence dimension."""

    start: int
    end: int  # exclusive

    def __post_init__(self):
        # Reject negative indices and empty/inverted ranges up front.
        if min(self.start, self.end) < 0:
            raise ValueError("Token indices must be non-negative")
        if self.end <= self.start:
            raise ValueError(f"Invalid range: [{self.start}, {self.end})")
@dataclass
class LayerRange:
    """Half-open range [start, end) of model layers to transfer."""

    start: int
    end: int  # exclusive

    def __post_init__(self):
        # Reject negative indices and empty/inverted ranges up front.
        if min(self.start, self.end) < 0:
            raise ValueError("Layer indices must be non-negative")
        if self.end <= self.start:
            raise ValueError(f"Invalid range: [{self.start}, {self.end})")
@dataclass
class KVSlice:
    """
    Specifies which portion of KV cache to transfer.
    """

    # Token/layer sub-ranges; None leaves the dimension unconstrained —
    # confirm exact semantics against the session implementations.
    token_range: Optional[TokenRange] = None
    layer_range: Optional[LayerRange] = None
    block_ids: List[int] = field(default_factory=list)  # Physical block IDs
    is_last_slice: bool = False  # marks the final slice of a session
class SessionStatus(Enum):
    """Status of a transfer session.
    Represents the various stages/statuses that a file transfer session can go through:
    - INIT: The session has been initialized but not yet ready.
    - READY: The session is ready to start transferring.
    - TRANSFERRING: The session is in progress, currently transferring data.
    - TRANSFERRED: The primary transfer has completed successfully.
    - AUX_TRANSFERRED: The auxiliary part (such as tokens) of the transfer has completed successfully.
    - COMPLETED: The entire session process, including all transfers, has been successfully completed.
    - CANCELED: The session has been canceled by the user or system.
    - ERROR: An error occurred during the session. The session could not complete successfully.
    """

    # Values are plain strings so statuses serialize/compare readably.
    INIT = "INIT"
    READY = "READY"
    TRANSFERRING = "TRANSFERRING"
    TRANSFERRED = "TRANSFERRED"
    AUX_TRANSFERRED = "AUX_TRANSFERRED"
    COMPLETED = "COMPLETED"
    CANCELED = "CANCELED"
    ERROR = "ERROR"
# Identifier for an individual transfer task within a session.
TaskIdType = int


@dataclass
class SessionState:
    """State of a transfer session."""

    status: SessionStatus  # current lifecycle stage
    finished_tasks: List[TaskIdType]  # IDs of tasks that have completed
@dataclass
class SessionArgsBase:
    """Base arguments for transfer sessions."""

    # Disaggregated-serving parameters identifying the request being served.
    params: DisaggregatedParams
class SenderBase(ABC):
    """Base class for sending KV cache data."""

    # Marker interface for now; concrete senders supply the transport.
    ...
class ReceiverBase(ABC):
    """Base class for receiving KV cache data."""

    # Marker interface for now; concrete receivers supply the transport.
    ...
class TxSessionBase(ABC):
    """Abstract transmission (sender-side) session over a SenderBase."""

    def __init__(self, sender: SenderBase, args: SessionArgsBase):
        """
        Initializes the transmission session.
        :param sender: The sender instance responsible for sending data.
        :param args: The session arguments.
        """
        self._sender = sender
        self._base_args = args

    @property
    @abstractmethod
    def state(self) -> SessionState:
        """
        Returns the current state of the session.
        """
        ...

    @abstractmethod
    def poll_task(self, id: TaskIdType) -> SessionStatus:
        """
        Polls the status of a specific task by its ID.
        :param id: The task ID to poll.
        """
        # NOTE: parameter name shadows the builtin ``id``; kept for interface
        # compatibility with existing implementations/callers.
        ...

    @abstractmethod
    def send(self, slice: KVSlice) -> TaskIdType:
        """
        Sends a slice of KV cache data and returns the task ID.
        :param slice: The KV slice to send.
        """
        ...

    @property
    @abstractmethod
    def exception(self) -> Optional[Exception]:
        """
        Returns any exception that occurred during the session.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """
        Closes the session and releases any resources.
        """
        ...
class RxSessionBase(ABC):
    """Abstract reception (receiver-side) session over a ReceiverBase."""

    def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
        """
        Initializes the reception session.
        :param receiver: The receiver instance responsible for receiving data.
        :param args: The session arguments.
        """
        self._receiver = receiver
        self._base_args = args

    @property
    @abstractmethod
    def state(self) -> SessionState:
        """
        Returns the current state of the session.
        """
        ...

    @abstractmethod
    def poll_task(self, task_id: TaskIdType) -> SessionStatus:
        """
        Polls the status of a specific task by its ID.
        :param task_id: The task ID to poll.
        """
        ...

    @abstractmethod
    def receive(self, slice: KVSlice) -> TaskIdType:
        """
        Receives a slice of KV cache data and returns the task ID.
        :param slice: The KV slice to receive.
        """
        ...

    @property
    @abstractmethod
    def exception(self) -> Optional[Exception]:
        """Returns any exception that occurred during the session."""
        ...

    @abstractmethod
    def close(self) -> None:
        """
        Closes the session and releases any resources.
        """
        ...

View File

@ -0,0 +1,219 @@
from abc import ABC, abstractmethod
from threading import Event, Lock, Thread
from typing import Callable, Optional
import zmq
from tensorrt_llm import logger
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
class MessengerInterface(ABC):
    """
    Abstract base class for messenger implementations.
    """

    @abstractmethod
    def start(self) -> None:
        """
        Start the messenger service.
        """
        ...

    @abstractmethod
    def send(self, messages: list[bytes], recipient: Optional[bytes] = None) -> None:
        """
        Send messages to a recipient.
        :param messages: List of byte messages to send.
        :param recipient: Optional recipient identifier.
        """
        ...

    @abstractmethod
    def receive(self) -> list[bytes]:
        """
        Receive messages.
        :return: List of byte messages received.
        """
        ...

    @abstractmethod
    def start_listener(self, on_message: Callable[[list[bytes]], Optional[bool]]) -> None:
        """
        Start a listener thread to handle incoming messages.
        :param on_message: Callback function to process received messages.
        """
        ...

    @abstractmethod
    def stop(self) -> None:
        """
        Stop the messenger service.
        """
        ...

    @property
    @abstractmethod
    def endpoint(self) -> str:
        """
        Get the endpoint of the messenger.
        :return: Endpoint string.
        """
        ...
def decode_message(
    message: list[bytes], encoding: str = "ascii", err_mode: str = "strict"
) -> tuple:
    """Decode every bytes frame of ``message`` into str, returned as a tuple.

    :param message: List of bytes frames (e.g. a ZMQ multipart message).
    :param encoding: Text encoding used for each frame.
    :param err_mode: Error handling passed to ``bytes.decode``.
    :raises ValueError: If ``message`` is not a list of bytes.
    """
    if not isinstance(message, list):
        raise ValueError("Input must be a list of bytes")
    for frame in message:
        if not isinstance(frame, bytes):
            raise ValueError("Input must be a list of bytes")
    return tuple(frame.decode(encoding, errors=err_mode) for frame in message)
class ZMQMessenger(MessengerInterface):
    """MessengerInterface implementation backed by ZeroMQ sockets.

    ROUTER/REP sockets bind (server side) and may auto-pick a TCP port;
    DEALER/REQ sockets connect (client side) and require an explicit endpoint.
    """

    SOCKET_MODES = {
        "ROUTER": zmq.ROUTER,  # Handles multiple connections and routes messages by address.
        "DEALER": zmq.DEALER,  # Load balances outgoing messages and receives replies fairly.
        "REQ": zmq.REQ,  # Sends requests and waits for replies (synchronous).
        "REP": zmq.REP,  # Receives requests and sends replies (synchronous).
    }

    def __init__(self, mode: str, endpoint: Optional[str] = None) -> None:
        """
        :param mode: One of SOCKET_MODES ("ROUTER", "DEALER", "REQ", "REP").
        :param endpoint: Required for DEALER/REQ; for ROUTER/REP, defaults to
            binding a wildcard TCP port on the local IP.
        :raises ValueError: On an unknown mode, or a missing client endpoint.
        """
        if mode not in self.SOCKET_MODES:
            raise ValueError(
                f"Invalid mode '{mode}'. Allowed modes are {list(self.SOCKET_MODES.keys())}"
            )
        self._context = zmq.Context()
        self._mode = mode
        self._socket = self._context.socket(self.SOCKET_MODES[mode])
        self._endpoint: Optional[str] = None
        self._lock = Lock()  # serializes stop() so teardown runs exactly once
        self._closed = False
        self._stop_event = Event()
        self._listener_thread: Optional[Thread] = None
        self._initialize_control_sockets()
        if endpoint is None:
            if mode in ["DEALER", "REQ"]:
                raise ValueError("endpoint is required for DEALER/REQ modes")
            endpoint = f"tcp://{get_local_ip()}:*"  # wildcard port, resolved after bind
        if mode in ["ROUTER", "REP"]:
            self._socket.bind(endpoint)
            # Resolve the wildcard port to the concrete bound endpoint.
            self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
        elif mode in ["DEALER", "REQ"]:
            self._socket.connect(endpoint)
            self._endpoint = endpoint
        logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})")

    def _initialize_control_sockets(self) -> None:
        # PAIR socket pair over inproc, used only to wake the listener
        # thread's poller from stop().
        self._control_socket = self._context.socket(zmq.PAIR)
        self._internal_socket = self._context.socket(zmq.PAIR)
        inproc_endpoint = "inproc://stop_listener"
        self._control_socket.bind(inproc_endpoint)
        self._internal_socket.connect(inproc_endpoint)

    def start(self) -> None:
        # Sockets are fully set up in __init__; nothing to do here.
        pass

    def send(self, messages: list[bytes], recipient: Optional[bytes] = None) -> None:
        """Send a multipart message; ``recipient`` prefixes a ROUTER identity frame."""
        if recipient:
            self._socket.send_multipart([recipient] + messages)
        else:
            self._socket.send_multipart(messages)

    def receive(self) -> list[bytes]:
        """Blocking receive of one multipart message."""
        return self._socket.recv_multipart()

    def start_listener(
        self,
        on_message: Callable[[list[bytes]], Optional[bool]],
        on_error: Optional[Callable[[Exception], None]] = None,
    ) -> None:
        """Spawn a daemon thread dispatching incoming messages to ``on_message``.

        :param on_message: Callback for each received multipart message; may
            return False to stop listening.
        :param on_error: Optional handler for listener exceptions; when omitted,
            an exception stops the listener.
        :raises RuntimeError: If a listener is already running.
        """
        assert self._mode in ["ROUTER", "REP"], (
            "Listener can only be started in ROUTER or REP modes"
        )
        if self._listener_thread and self._listener_thread.is_alive():
            raise RuntimeError("Listener already running")

        def handle_listener_exceptions(
            exception: Exception, on_error: Optional[Callable[[Exception], None]]
        ) -> None:
            logger.error(f"Error in listener: {exception}")
            if on_error:
                on_error(exception)
            else:
                self._stop_event.set()

        def listener() -> None:
            poller = zmq.Poller()
            poller.register(self._socket, zmq.POLLIN)
            poller.register(self._control_socket, zmq.POLLIN)
            while not self._stop_event.is_set():
                events = dict(poller.poll(timeout=100))
                try:
                    if self._control_socket in events:
                        # stop() pushed a STOP message on the control socket.
                        self._stop_event.set()
                    elif self._socket in events:
                        messages = self.receive()
                        persist = on_message(messages)
                        if persist is False:
                            self._stop_event.set()
                except zmq.ZMQError as e:
                    handle_listener_exceptions(e, on_error)
                    break
                except Exception as e:
                    handle_listener_exceptions(e, on_error)
                    break
            self._stop_event.set()

        self._listener_thread = Thread(target=listener, daemon=True)
        self._listener_thread.start()

    def stop(self, timeout: int = 5) -> None:
        """Stop the listener (if any) and release all ZMQ resources. Idempotent."""

        def _close_socket(socket: zmq.Socket) -> None:
            try:
                if not socket.closed:
                    socket.close()
            except Exception as e:
                logger.error(f"Error closing socket: {e}")

        with self._lock:
            if self._closed:
                return
            self._closed = True
            logger.debug("Stopping ZMQMessenger...")
            self._stop_event.set()
            # A single STOP wakes the listener's poller (previously this was
            # sent twice, leaving a stray message queued on the PAIR socket).
            self._internal_socket.send(b"STOP")
            if self._listener_thread:
                self._listener_thread.join(timeout)
                if self._listener_thread.is_alive():
                    logger.warning("Listener thread did not terminate within timeout")
            _close_socket(self._socket)
            _close_socket(self._internal_socket)
            _close_socket(self._control_socket)
            try:
                if not self._context.closed:
                    self._context.term()
            except Exception as e:
                logger.error(f"Error terminating ZMQ context: {e}")

    @property
    def endpoint(self) -> str:
        """The bound/connected endpoint; always set after a successful __init__."""
        assert self._endpoint is not None
        return self._endpoint

    def __enter__(self) -> "ZMQMessenger":
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[object],  # was bare `Optional`, an invalid annotation
    ) -> None:
        self.stop()

View File

@ -0,0 +1,41 @@
from tensorrt_llm import logger
def get_local_ip() -> str:
    """Best-effort detection of a non-loopback IPv4 address for this host.

    Tries three strategies in order and falls back to ``127.0.0.1``:
      1. A connected (packet-free) UDP socket, which asks the kernel which
         source address it would route from.
      2. Interface enumeration via the optional ``netifaces`` package.
      3. Resolving the local hostname.

    :return: A dotted-quad IPv4 address string.
    """
    # Hoisted to function scope: previously this import lived inside the first
    # try, so an early failure there left strategy 3 raising an uncaught
    # NameError on the unbound `socket` name.
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            # connect() on UDP sends no packets; it only selects a route.
            s.connect(("10.255.255.255", 1))
            ip = s.getsockname()[0]
            if not ip.startswith("127."):
                return ip
    except (OSError, ValueError):
        logger.error("Failed to get local IP via UDP socket method.")
    try:
        import netifaces

        for iface in netifaces.interfaces():
            addrs = netifaces.ifaddresses(iface)
            if netifaces.AF_INET in addrs:
                for addr in addrs[netifaces.AF_INET]:
                    ip = addr.get("addr", "")
                    # Skip loopback and link-local (169.254/16) addresses.
                    if not ip.startswith("127.") and not ip.startswith("169.254"):
                        return ip
    except (ImportError, OSError, ValueError) as e:
        logger.error(f"Failed to get local IP via netifaces: {e}")
    try:
        hostname = socket.gethostname()
        ip = socket.gethostbyname(hostname)
        if not ip.startswith("127."):
            return ip
    except (OSError, ValueError):
        logger.error("Failed to get local IP via hostname resolution.")
    return "127.0.0.1"

View File

@ -0,0 +1,147 @@
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( # noqa: E402
AgentDesc,
BaseAgentConfig,
MemoryDescs,
MemoryType,
TransferState,
)
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
NixlTransferAgent as CppNixlTransferAgent,
)
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
class BindingsNixlTransferStatus(TransferStatus):
    """TransferStatus wrapper using C++ bindings with GIL release."""

    def __init__(self, cpp_status):
        self._cpp_status = cpp_status

    def is_completed(self) -> bool:
        """Check if transfer is completed (releases GIL)."""
        return self._cpp_status.is_completed()

    @nvtx_range("BindingsNixlTransferStatus.wait")
    def wait(self, timeout_ms=None) -> bool:
        """Wait for transfer to complete (releases GIL); -1 is passed when no
        timeout is given."""
        effective_timeout = -1 if timeout_ms is None else timeout_ms
        return self._cpp_status.wait(effective_timeout) == TransferState.SUCCESS
class BindingsNixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using C++ bindings with GIL release support.
    This implementation uses the standalone nixl_bindings C++ module which releases
    the GIL during blocking operations like wait().
    The nixl_bindings module is independent from the main trtllm bindings,
    so trtllm can function normally even without NIXL.
    """

    def __init__(
        self,
        name: str,
        use_prog_thread: bool = True,
        num_threads: int = 1,
        enable_telemetry: bool = False,
        **kwargs,
    ):
        # Extra kwargs are forwarded as string-valued backend parameters.
        # (Replacing values for existing keys while iterating is safe in
        # CPython since the dict size does not change.)
        backend_params = kwargs
        for key, value in backend_params.items():
            backend_params[key] = str(value)
        # NOTE(review): the C++ NixlTransferAgent reads "num_workers" from
        # backendParams; this writes "num_threads" — confirm which key is
        # intended, otherwise num_threads may be silently ignored.
        backend_params["num_threads"] = str(num_threads)
        config = BaseAgentConfig(
            name,
            use_prog_thread,
            multi_thread=False,
            use_listen_thread=False,
            enable_telemetry=enable_telemetry,
            backend_params=backend_params,
        )
        self._cpp_agent = CppNixlTransferAgent(config)
        self.name = name

    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.register_memory(cpp_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.deregister_memory(cpp_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Load a remote agent by its descriptor (bytes)."""
        # AgentDesc expects std::string which can hold binary data
        desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode()
        cpp_desc = AgentDesc(desc_str)
        self._cpp_agent.load_remote_agent(name, cpp_desc)

    def load_remote_agent_by_connection(self, name: str, connection_info: str):
        """Load a remote agent by connection info."""
        self._cpp_agent.load_remote_agent_by_connection(name, connection_info)

    def get_local_agent_desc(self) -> bytes:
        """Get the local agent descriptor as bytes."""
        agent_desc = self._cpp_agent.get_local_agent_desc()
        return agent_desc.backend_agent_desc  # Returns bytes

    def get_local_connection_info(self) -> str:
        """Get the local connection info."""
        return self._cpp_agent.get_local_connection_info()

    def invalidate_remote_agent(self, name: str):
        """Invalidate a remote agent."""
        self._cpp_agent.invalidate_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool:
        """Check if remote descriptors are available.
        memory_descs should be C++ MemoryDescs type.
        """
        return self._cpp_agent.check_remote_descs(name, memory_descs)

    def notify_sync_message(self, name: str, sync_message: str):
        """Send a sync message to a remote agent."""
        self._cpp_agent.notify_sync_message(name, sync_message)

    def get_notified_sync_messages(self):
        """Get notified sync messages."""
        return self._cpp_agent.get_notified_sync_messages()

    @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Submit transfer requests and return status.
        request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding).
        """
        cpp_status = self._cpp_agent.submit_transfer_requests(request)
        return BindingsNixlTransferStatus(cpp_status)

    def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs":
        """Convert Python RegMemoryDescs to C++ MemoryDescs.
        RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name)
        Extract first 3 elements for C++ batch constructor.
        """
        mem_type = self._convert_memory_type(descs.type)
        # Extract (ptr, size, device_id) from 4-tuple, discard name
        tuples = [(d[0], d[1], d[2]) for d in descs.descs]
        return MemoryDescs(mem_type, tuples)

    def _convert_memory_type(self, py_type: str) -> "MemoryType":
        """Convert Python memory type string to C++ MemoryType."""
        # Unknown strings silently map to VRAM (the common case here).
        type_map = {
            "DRAM": MemoryType.DRAM,
            "VRAM": MemoryType.VRAM,
            "GPU": MemoryType.VRAM,
            "BLK": MemoryType.BLK,
            "OBJ": MemoryType.OBJ,
            "FILE": MemoryType.FILE,
        }
        return type_map.get(py_type.upper(), MemoryType.VRAM)

View File

@ -0,0 +1,114 @@
import time
from enum import Enum
from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle
from tensorrt_llm._utils import nvtx_range
# Import base classes for type compatibility
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
class TransferState(Enum):
    """States returned by nixl's check_xfer_state for an in-flight transfer."""

    PENDING = "PENDING"
    PROCESSING = "PROC"  # nixl reports the short form "PROC"
    DONE = "DONE"
    ERROR = "ERROR"
class NixlTransferStatus(TransferStatus):
    """Polls a nixl transfer handle until it completes, fails, or times out."""

    def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle):
        self.agent = agent
        self.handle = handle

    def is_completed(self):
        """True once nixl reports the transfer as DONE."""
        status = TransferState(self.agent.check_xfer_state(self.handle))
        return status == TransferState.DONE

    def wait(self, timeout_ms=None):
        """Busy-wait (with exponential backoff) for the transfer to finish.

        :param timeout_ms: Maximum wait in milliseconds; None waits forever.
        :return: True on success, False on transfer error or timeout.
        """
        start_time = time.time()
        sleep_time = 0.0001  # 0.1ms in seconds
        max_sleep_time = 0.01  # 10ms in seconds
        timeout = timeout_ms / 1000 if timeout_ms is not None else None
        while True:
            status = TransferState(self.agent.check_xfer_state(self.handle))
            # Check the state *before* sleeping so an already-finished
            # transfer returns immediately (the previous loop slept once
            # even after DONE was observed).
            if status == TransferState.DONE:
                return True
            if status == TransferState.ERROR:
                return False  # Transfer failed
            if timeout is not None and (time.time() - start_time > timeout):
                return False  # Timeout
            time.sleep(sleep_time)
            sleep_time = min(sleep_time * 2, max_sleep_time)
class NixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using the Python nixl library."""

    def __init__(self, name: str, use_prog_thread: bool = True, num_threads: int = 1, **kwargs):
        """
        Initialize NixlTransferAgent.

        :param name: Name of the agent.
        :param use_prog_thread: Whether to enable the progress thread, if available.
        :param num_threads: Number of threads for the supported multi-threaded
            backends. (Docstring previously named a nonexistent ``num_workers``.)
        :param kwargs: Accepted for signature compatibility; currently unused.
        """
        self.name = name
        self.backends = ["UCX"]
        agent_config = nixl_agent_config(
            enable_prog_thread=use_prog_thread, backends=self.backends, num_threads=num_threads
        )
        self.agent = nixl_agent(name, agent_config)

    def _to_reg_descs(self, descs: RegMemoryDescs):
        """Validate descs and convert them into a nixl registration list."""
        if not descs.descs:
            raise ValueError("descs.descs must not be empty")
        if isinstance(descs.descs[0], tuple):
            # Tuple form must be (ptr, size, device_id, name).
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        return reg_descs

    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions with the nixl agent."""
        self.agent.register_memory(self._to_reg_descs(descs))

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister memory regions from the nixl agent."""
        self.agent.deregister_memory(self._to_reg_descs(descs))

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Load a remote agent from its serialized metadata (``name`` unused by nixl)."""
        self.agent.add_remote_agent(agent_desc)

    def get_local_agent_desc(self):
        """Serialized metadata describing this agent, for exchange with peers."""
        return self.agent.get_agent_metadata()

    def invalidate_remote_agent(self, name: str):
        """Forget a previously loaded remote agent."""
        self.agent.remove_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool:
        raise NotImplementedError

    def notify_sync_message(self, name: str, sync_message: str):
        raise NotImplementedError

    @nvtx_range("NixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Initialize and start a nixl transfer; returns a pollable status.

        :raises RuntimeError: If nixl reports "ERROR" when starting the transfer.
        """
        src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type)
        dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type)
        assert src_xfer_descs is not None, "Failed to get src_xfer_descs"
        assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs"
        sync_message = "" if request.sync_message is None else request.sync_message
        handle = self.agent.initialize_xfer(
            request.op,
            src_xfer_descs,
            dst_xfer_descs,
            request.remote_name,
            sync_message,
        )
        status = self.agent.transfer(handle)
        if status == "ERROR":
            raise RuntimeError("NIXL transfer initialization failed.")
        return NixlTransferStatus(self.agent, handle)

View File

@ -0,0 +1,43 @@
"""NIXL Transfer Agent implementations.

This module provides two implementations:

1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support
2. NixlTransferAgent - Uses the Python nixl library directly (fallback)

The standalone nixl_bindings module is separate from the main trtllm bindings,
so trtllm can still function normally even without NIXL dependencies.
"""

# The docstring must precede the imports: placed after them it is a dead string
# expression rather than the module's __doc__.
from tensorrt_llm.logger import logger

from ..base.agent import use_pure_python_transfer_agent
def _load_agent(module_name, required_attributes):
    """Import ``module_name`` and return it only if it exposes every required attribute.

    Returns None when the import fails (logged) or an attribute is missing.
    """
    try:
        candidate = __import__(module_name, fromlist=required_attributes, level=0)
    except ImportError as err:
        logger.info("Failed to import module: %s. Error: %s", module_name, str(err))
        return None
    for attribute in required_attributes:
        if not hasattr(candidate, attribute):
            return None
    return candidate
# Select the transfer-agent implementation: pure Python when requested via the
# environment toggle, otherwise the C++ bindings. Both branches expose the same
# two names to the rest of the package.
NixlTransferStatus, NixlTransferAgent = None, None
if use_pure_python_transfer_agent():
    _module_name = "tensorrt_llm._torch.disaggregation.nixl._agent_py"
    _agent_attr, _status_attr = "NixlTransferAgent", "NixlTransferStatus"
    _error = "Failed to load pure Python NIXL Transfer Agent."
else:
    _module_name = "tensorrt_llm._torch.disaggregation.nixl._agent_cpp"
    _agent_attr, _status_attr = "BindingsNixlTransferAgent", "BindingsNixlTransferStatus"
    _error = "Failed to load C++ NIXL Transfer Agent bindings."
_impl = _load_agent(module_name=_module_name, required_attributes=[_agent_attr, _status_attr])
assert _impl is not None, _error
NixlTransferStatus = getattr(_impl, _status_attr)
NixlTransferAgent = getattr(_impl, _agent_attr)

View File

@ -33,6 +33,8 @@ l0_a10:
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py
- unittest/disaggregated/test_agent_multi_backends.py
- unittest/disaggregated/test_messenger.py
- unittest/disaggregated/test_disagg_cluster_manager_worker.py
- unittest/disaggregated/test_cluster_storage.py
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]

View File

@ -37,6 +37,8 @@ l0_h100:
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py
- unittest/disaggregated/test_agent_multi_backends.py
- unittest/disaggregated/test_messenger.py
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse

View File

@ -142,21 +142,24 @@ def test_base_agent_config_custom():
use_prog_thread = True
multi_thread = False
use_listen_thread = True
num_workers = 4
enable_telemetry = True
backend_params = {"key1": "value1", "key2": "value2"}
config = tab.BaseAgentConfig(
name=name,
use_prog_thread=use_prog_thread,
multi_thread=multi_thread,
use_listen_thread=use_listen_thread,
num_workers=num_workers,
enable_telemetry=enable_telemetry,
backend_params=backend_params,
)
assert config.name == name
assert config.use_prog_thread == use_prog_thread
assert config.multi_thread == multi_thread
assert config.use_listen_thread == use_listen_thread
assert config.num_workers == num_workers
assert config.enable_telemetry == enable_telemetry
assert config.backend_params == backend_params
def test_base_agent_config_readwrite():
@ -175,8 +178,11 @@ def test_base_agent_config_readwrite():
config.use_listen_thread = True
assert config.use_listen_thread is True
config.num_workers = 8
assert config.num_workers == 8
config.enable_telemetry = True
assert config.enable_telemetry is True
config.backend_params = {"test_key": "test_value"}
assert config.backend_params == {"test_key": "test_value"}
def test_transfer_request():

View File

@ -0,0 +1,176 @@
from dataclasses import dataclass, field
from unittest import TestCase
from unittest.mock import Mock
import pytest
import torch
from tensorrt_llm import logger
from tensorrt_llm._torch.disaggregation.base.agent import (
MemoryDescs,
MemoryType,
RegMemoryDescs,
TransferOp,
TransferRequest,
TransferStatus,
)
from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent
class TestTransferStatus(TestCase):
    """Exercise the TransferStatus contract through a Mock stand-in."""

    def test_mock_transfer_status(self):
        status_mock = Mock(spec=TransferStatus)

        status_mock.is_completed.return_value = True
        self.assertTrue(status_mock.is_completed())
        status_mock.is_completed.assert_called_once()

        status_mock.wait.return_value = True
        for timeout_ms in (None, 1000, 5000):
            with self.subTest(timeout=timeout_ms):
                self.assertTrue(status_mock.wait(timeout_ms=timeout_ms))
                status_mock.wait.assert_called_with(timeout_ms=timeout_ms)
def _convert_to_memory_descs(reg_descs: RegMemoryDescs) -> MemoryDescs:
    """Convert 4-tuple registration descs to 3-tuple MemoryDescs for the bindings.

    Drops the trailing name from each (ptr, size, device_id, name) tuple and maps
    the string memory type onto MemoryType; unknown strings fall back to VRAM.
    """
    type_map = {
        "DRAM": MemoryType.DRAM,
        "VRAM": MemoryType.VRAM,
        "GPU": MemoryType.VRAM,
        "BLK": MemoryType.BLK,
        "OBJ": MemoryType.OBJ,
        "FILE": MemoryType.FILE,
    }
    mem_type = type_map.get(reg_descs.type.upper(), MemoryType.VRAM)
    stripped = [(addr, length, dev) for addr, length, dev, _name in reg_descs.descs]
    return MemoryDescs(mem_type, stripped)
@dataclass
class MemoryManager:
    """Allocates torch buffers for transfer tests and keeps them alive."""

    # Tensors stay referenced here so registered memory is not freed mid-test.
    allocated_memory: list[torch.Tensor] = field(default_factory=list)

    def allocate_memory(
        self, size: int, name: str, memory_type=MemoryType.VRAM, device_id: int = 0
    ) -> RegMemoryDescs:
        """Allocate ``size`` bytes and return 4-tuple reg descs (ptr, len, device_id, name).

        ``memory_type`` may be a ``MemoryType`` member or its string name
        ("VRAM"/"GPU"/"DRAM"), since the fixtures parameterize with plain
        strings; the previous enum-only comparison would silently place
        string-typed "VRAM" buffers on the CPU.
        """
        # Normalize so both MemoryType.VRAM and the strings "VRAM"/"GPU" select CUDA.
        is_vram = memory_type == MemoryType.VRAM or memory_type in ("VRAM", "GPU")
        device = torch.device(f"cuda:{device_id}" if is_vram else "cpu")
        # Allocate memory block using torch.Tensor and track it
        block = torch.zeros(size, dtype=torch.uint8, device=device)
        self.allocated_memory.append(block)
        # Return RegMemoryDescs with position arguments
        memory_descs = RegMemoryDescs(
            type=memory_type, descs=[(block.data_ptr(), block.numel(), device_id, name)]
        )
        return memory_descs

    def clear_memory(self):
        """Drop references to all tracked memory blocks."""
        self.allocated_memory.clear()
@pytest.fixture
def memory_manager():
    # Fresh manager per test; keeps allocated tensors alive for the test's duration.
    return MemoryManager()


@pytest.fixture(params=[256, 512])
def memory_size(request):
    # Buffer sizes (bytes) exercised by the parameterized tests.
    return request.param


@pytest.fixture(params=["DRAM", "VRAM"])
def memory_type(request):
    # Memory type as a plain string, matching the agent descs' `type` field.
    return request.param
@pytest.fixture
def alloc(memory_manager, memory_size, memory_type):
    """Allocate memory for source and destination, based on the memory_size and memory_type parameters."""
    # NOTE(review): memory_type is the string "DRAM"/"VRAM" here, not a MemoryType
    # enum — MemoryManager.allocate_memory must accept both forms; verify.
    assert memory_size > 0, "Memory size must be a positive integer."
    if memory_type == "VRAM" and not torch.cuda.is_available():
        pytest.skip("CUDA not available for VRAM transfer tests")
    # Two equal-size buffers: index 0 is the source, index 1 the destination.
    src_descs = memory_manager.allocate_memory(
        size=memory_size, name="src_mem", memory_type=memory_type
    )
    dst_descs = memory_manager.allocate_memory(
        size=memory_size, name="dst_mem", memory_type=memory_type
    )
    return src_descs, dst_descs
@pytest.fixture
def transfer_agent_src():
    # Source-side NIXL agent; the name must match the load_remote_agent calls below.
    return NixlTransferAgent(name="src_agent")


@pytest.fixture
def transfer_agent_dst():
    # Destination-side NIXL agent.
    return NixlTransferAgent(name="dst_agent")
def test_transfer_between_agents(
    transfer_agent_src,
    transfer_agent_dst,
    memory_manager,
    alloc,
    memory_size,
    memory_type,
):
    """End-to-end test of data transfer between two agents with parameterized memory sizes and types."""
    # Debug log the parameters being tested
    logger.info(f"Testing with memory_size={memory_size}, memory_type={memory_type}")
    # Unpack source and destination memory descriptions
    memory_descs_src, memory_descs_dst = alloc
    # Fill source memory with a repeating 0..9 pattern for post-transfer validation
    src_data = memory_manager.allocated_memory[0]
    assert memory_size > 0, "Memory size must be positive."
    tensor = torch.arange(memory_size, dtype=torch.uint8) % 10
    src_data.copy_(tensor)
    # Register memory with source and destination agents
    transfer_agent_src.register_memory(memory_descs_src)
    transfer_agent_dst.register_memory(memory_descs_dst)
    # Exchange agent metadata in both directions so each side can address the other
    src_agent_desc = transfer_agent_src.get_local_agent_desc()
    transfer_agent_dst.load_remote_agent("src_agent", src_agent_desc)
    dst_agent_desc = transfer_agent_dst.get_local_agent_desc()
    transfer_agent_src.load_remote_agent("dst_agent", dst_agent_desc)
    # Create and submit the transfer request (WRITE pushes src -> dst)
    transfer_request = TransferRequest(
        op=TransferOp.WRITE,
        src_descs=_convert_to_memory_descs(memory_descs_src),
        dst_descs=_convert_to_memory_descs(memory_descs_dst),
        remote_name="dst_agent",
        sync_message=None,
    )
    transfer_status = transfer_agent_src.submit_transfer_requests(transfer_request)
    transfer_status.wait(timeout_ms=5000)
    assert transfer_status.wait(timeout_ms=5000), "Transfer did not complete within timeout."
    # Validate transfer completion
    assert transfer_status.is_completed(), "Transfer did not complete successfully."
    # Validate that the destination data matches the source data
    dst_data = memory_manager.allocated_memory[1]
    assert torch.equal(dst_data, src_data), "Destination data does not match source data."
    # Clean up by deregistering memory and clearing allocations
    transfer_agent_src.deregister_memory(memory_descs_src)
    transfer_agent_dst.deregister_memory(memory_descs_dst)
    memory_manager.clear_memory()
    # Invalidate peers last, after all transfers and deregistration are done
    transfer_agent_src.invalidate_remote_agent("dst_agent")
    transfer_agent_dst.invalidate_remote_agent("src_agent")


if __name__ == "__main__":
    pytest.main()

View File

@ -0,0 +1,32 @@
import os
import subprocess
import pytest
@pytest.mark.parametrize("use_py_nixl", ["0", "1"])
def test_run_with_different_env(use_py_nixl):
    """Run test_agent.py in a subprocess with TRTLLM_USE_PY_NIXL_KVCACHE toggled.

    The variable is injected into the child's environment only; the previous
    version mutated os.environ, leaking the setting into every later test in
    the parent pytest session.
    """
    child_env = {**os.environ, "TRTLLM_USE_PY_NIXL_KVCACHE": use_py_nixl}
    print(f"Running tests with TRTLLM_USE_PY_NIXL_KVCACHE={use_py_nixl}")
    test_file_path = os.path.join(os.path.dirname(__file__), "test_agent.py")
    print(f"Running tests in: {test_file_path}")
    result = subprocess.run(
        ["pytest", "--capture=no", test_file_path],
        env=child_env,
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.returncode != 0:
        # Surface the child's stderr so CI logs show why the suite failed.
        print("Test failed. stderr output:")
        print(result.stderr)
    assert result.returncode == 0, f"Tests failed with TRTLLM_USE_PY_NIXL_KVCACHE={use_py_nixl}"


if __name__ == "__main__":
    pytest.main()

View File

@ -0,0 +1,127 @@
import socket
import time
import unittest
import pytest
from parameterized import parameterized
from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
# Table-driven cases for decode_message(): each entry names the scenario, the
# raw multipart message, the decode options, and either the expected decoded
# tuple or the exception type the call must raise.
TEST_CASES = [
    {
        "name": "valid_message",
        "message": [b"hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": ("hello", "world"),
        "raises": None,
    },
    {
        # Non-bytes frame must be rejected.
        "name": "invalid_input",
        "message": ["hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": ValueError,
    },
    {
        # 0xFF is not valid UTF-8; strict mode must raise.
        "name": "decoding_error",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": UnicodeDecodeError,
    },
    {
        # Same invalid byte, but "ignore" drops it and yields an empty string.
        "name": "decoding_with_ignore",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "ignore",
        "expected": ("",),
        "raises": None,
    },
]
class TestDecodeMessage(unittest.TestCase):
    """Table-driven checks for decode_message() using the TEST_CASES table."""

    @parameterized.expand([(case["name"], case) for case in TEST_CASES])
    def test_decode_message(self, name, case):
        decode_kwargs = dict(encoding=case["encoding"], err_mode=case["err_mode"])
        if case["raises"]:
            with self.assertRaises(case["raises"]):
                decode_message(case["message"], **decode_kwargs)
        else:
            self.assertEqual(
                decode_message(case["message"], **decode_kwargs), case["expected"]
            )
@pytest.fixture
def dynamic_endpoint():
    """Fixture to dynamically generate an available endpoint with a free port."""
    # NOTE(review): the socket is closed before the messenger binds the port, so
    # another process could grab it in between (TOCTOU) — accepted flakiness
    # risk for a test-only free-port probe.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # Bind to an available port provided by the OS
        port = s.getsockname()[1]
    return f"tcp://{get_local_ip()}:{port}"
@pytest.fixture
def create_messenger_pair(dynamic_endpoint):
    """Factory fixture that builds a (bind-side, connect-side) ZMQMessenger pair."""

    def _make(mode1, mode2):
        # The first messenger is the bind side (ROUTER/REP receive the
        # endpoint); the second is the connect side (DEALER/REQ receive it).
        bind_endpoint = dynamic_endpoint if mode1 in ("ROUTER", "REP") else None
        connect_endpoint = dynamic_endpoint if mode2 in ("DEALER", "REQ") else None
        first = ZMQMessenger(mode1, endpoint=bind_endpoint)
        second = ZMQMessenger(mode2, endpoint=connect_endpoint)
        return first, second

    yield _make
def test_router_dealer(create_messenger_pair):
    """Test ROUTER and DEALER communication."""
    router, dealer = create_messenger_pair("ROUTER", "DEALER")
    received_messages = []

    def on_message(messages):
        received_messages.extend(messages)

    try:
        router.start_listener(on_message)
        dealer.send([b"Hello, ROUTER!"])
        # Poll with a deadline instead of a single fixed sleep: a 0.1s nap is a
        # flakiness source on loaded CI machines, while polling exits early on
        # fast ones.
        deadline = time.time() + 5.0
        while not received_messages and time.time() < deadline:
            time.sleep(0.01)
        assert len(received_messages) > 0
        assert b"Hello, ROUTER!" in received_messages
    finally:
        # Release sockets even when an assertion above fails, so later tests
        # are not starved of ports/contexts.
        router.stop()
        dealer.stop()
def test_req_rep(create_messenger_pair):
    """Test REQ and REP communication."""
    rep, req = create_messenger_pair("REP", "REQ")

    def on_message(messages):
        # Echo the request straight back to the REQ socket.
        rep.send(messages)

    try:
        rep.start_listener(on_message)
        req.send([b"Hello, REP!"])
        response = req.receive()
        assert response == [b"Hello, REP!"]
    finally:
        # Release sockets even when the assertion fails.
        req.stop()
        rep.stop()


if __name__ == "__main__":
    unittest.main()