From 944c304bbb7860031c99cb298fee70f462e17bd7 Mon Sep 17 00:00:00 2001
From: Shi Xiaowei <39303645+Shixiaowei02@users.noreply.github.com>
Date: Fri, 23 Jan 2026 02:14:50 +0800
Subject: [PATCH] [TRTLLM-9527][feat] Python transceiver components (step 2) (#10494)

Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
---
 .../tensorrt_llm/executor/transferAgent.h     |   3 +-
 .../agent_utils/connection.cpp                |   2 +-
 .../nixl_utils/agentBindingsNanobind.cpp      |  32 ++-
 .../nixl_utils/agentBindingsPybind.cpp        |  26 ++-
 .../nixl_utils/transferAgent.cpp              |  21 +-
 .../unit_tests/executor/transferAgentTest.cpp |  12 +-
 requirements-dev.txt                          |   1 +
 .../_torch/disaggregation/__init__.py         |   0
 .../_torch/disaggregation/base/__init__.py    |   0
 .../_torch/disaggregation/base/agent.py       | 145 ++++++++++++
 .../_torch/disaggregation/base/kv_transfer.py | 200 ++++++++++++++++
 .../_torch/disaggregation/native/__init__.py  |   0
 .../_torch/disaggregation/native/messenger.py | 219 ++++++++++++++++++
 .../_torch/disaggregation/native/utils.py     |  41 ++++
 .../_torch/disaggregation/nixl/__init__.py    |   0
 .../_torch/disaggregation/nixl/_agent_cpp.py  | 147 ++++++++++++
 .../_torch/disaggregation/nixl/_agent_py.py   | 114 +++++++++
 .../_torch/disaggregation/nixl/agent.py       |  43 ++++
 .../integration/test_lists/test-db/l0_a10.yml |   2 +
 .../test_lists/test-db/l0_h100.yml            |   2 +
 .../bindings/test_transfer_agent_bindings.py  |  16 +-
 tests/unittest/disaggregated/test_agent.py    | 176 ++++++++++++++
 .../test_agent_multi_backends.py              |  32 +++
 .../unittest/disaggregated/test_messenger.py  | 127 ++++++++++
 24 files changed, 1331 insertions(+), 30 deletions(-)
 create mode 100644 tensorrt_llm/_torch/disaggregation/__init__.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/base/__init__.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/base/agent.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/native/__init__.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/native/messenger.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/native/utils.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/__init__.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py
 create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/agent.py
 create mode 100644 tests/unittest/disaggregated/test_agent.py
 create mode 100644 tests/unittest/disaggregated/test_agent_multi_backends.py
 create mode 100644 tests/unittest/disaggregated/test_messenger.py

diff --git a/cpp/include/tensorrt_llm/executor/transferAgent.h b/cpp/include/tensorrt_llm/executor/transferAgent.h
index 9bd0f97e3d..8d6a461076 100644
--- a/cpp/include/tensorrt_llm/executor/transferAgent.h
+++ b/cpp/include/tensorrt_llm/executor/transferAgent.h
@@ -296,7 +296,8 @@ struct BaseAgentConfig
     bool useProgThread;
     bool multiThread;
     bool useListenThread;
-    unsigned int numWorkers;
+    bool enableTelemetry;
+    std::unordered_map<std::string, std::string> backendParams;
 };
 
 class BaseTransferAgent
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
index 92dba4519d..3e9c7485bb 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
@@ -247,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
     mAgentName = genUniqueAgentName();
 
     // Create Agent
-    BaseAgentConfig config{mAgentName, true, false, true, 1};
+    BaseAgentConfig config{mAgentName, true, false, true};
     m_Agent = makeTransferAgent(backendType, &config);
     TLLM_CHECK(!mCacheTransBufferManagers.empty());
     std::vector<MemoryDesc> memDescs;
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsNanobind.cpp b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsNanobind.cpp
index 671b4e8ad3..31684347cf 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsNanobind.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsNanobind.cpp
@@ -26,10 +26,9 @@
 #endif
 
 #include
-#include
 #include
-#include
 #include
+#include
 #include
 #include
@@ -69,6 +68,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
     // MemoryDescs class
     nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
         .def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
+        // Batch constructor from list of tuples: [(ptr, size, device_id), ...]
+        .def(
+            "__init__",
+            [](kvc::MemoryDescs* self, kvc::MemoryType type,
+                std::vector<std::tuple<uintptr_t, size_t, int>> const& tuples)
+            {
+                std::vector<kvc::MemoryDesc> descs;
+                descs.reserve(tuples.size());
+                for (auto const& [addr, len, deviceId] : tuples)
+                {
+                    descs.emplace_back(addr, len, deviceId);
+                }
+                new (self) kvc::MemoryDescs(type, std::move(descs));
+            },
+            nb::arg("type"), nb::arg("tuples"))
         .def_prop_ro("type", &kvc::MemoryDescs::getType)
         .def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
@@ -113,17 +127,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
         .def(
             "__init__",
             [](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
-                bool use_listen_thread, unsigned int num_workers) {
-                new (self) kvc::BaseAgentConfig{
-                    std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
+                bool use_listen_thread, bool enable_telemetry,
+                std::unordered_map<std::string, std::string> backend_params)
+            {
+                new (self) kvc::BaseAgentConfig{std::move(name), use_prog_thread, multi_thread, use_listen_thread,
+                    enable_telemetry, std::move(backend_params)};
             },
             nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
-            nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
+            nb::arg("use_listen_thread") = false, nb::arg("enable_telemetry") = false,
+            nb::arg("backend_params") = std::unordered_map<std::string, std::string>{})
         .def_rw("name", &kvc::BaseAgentConfig::mName)
         .def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
         .def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
         .def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
-        .def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
+        .def_rw("enable_telemetry", &kvc::BaseAgentConfig::enableTelemetry)
+        .def_rw("backend_params", &kvc::BaseAgentConfig::backendParams);
 
     // BaseTransferAgent class (abstract base)
     nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
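For orientation, here is how the new batch constructor is meant to be driven from Python. The module path matches the import used in the binding tests further down; the address value is a placeholder.

```python
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as tab

buf_ptr = 0x7F0000000000  # placeholder address; real code uses tensor.data_ptr()

# One call builds every MemoryDesc from (ptr, size, device_id) tuples, instead
# of constructing MemoryDesc objects one by one on the Python side.
descs = tab.MemoryDescs(
    tab.MemoryType.DRAM,
    [(buf_ptr, 4096, 0), (buf_ptr + 4096, 4096, 0)],
)
assert len(descs.descs) == 2
```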
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsPybind.cpp b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsPybind.cpp
index 2a15144571..f7bf6ca355 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsPybind.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindingsPybind.cpp
@@ -66,6 +66,19 @@ PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
     // MemoryDescs class
     py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
         .def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
+        // Batch constructor from list of tuples: [(ptr, size, device_id), ...]
+        .def(py::init(
+                 [](kvc::MemoryType type, std::vector<std::tuple<uintptr_t, size_t, int>> const& tuples)
+                 {
+                     std::vector<kvc::MemoryDesc> descs;
+                     descs.reserve(tuples.size());
+                     for (auto const& [addr, len, deviceId] : tuples)
+                     {
+                         descs.emplace_back(addr, len, deviceId);
+                     }
+                     return kvc::MemoryDescs(type, std::move(descs));
+                 }),
+            py::arg("type"), py::arg("tuples"))
         .def_property_readonly("type", &kvc::MemoryDescs::getType)
         .def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
@@ -108,17 +121,20 @@ PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
         .def(py::init<>())
         .def(py::init(
                  [](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
-                     unsigned int num_workers) {
-                     return kvc::BaseAgentConfig{
-                         std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
+                     bool enable_telemetry, std::unordered_map<std::string, std::string> backend_params)
+                 {
+                     return kvc::BaseAgentConfig{std::move(name), use_prog_thread, multi_thread, use_listen_thread,
+                         enable_telemetry, std::move(backend_params)};
                  }),
             py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
-            py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
+            py::arg("use_listen_thread") = false, py::arg("enable_telemetry") = false,
+            py::arg("backend_params") = std::unordered_map<std::string, std::string>{})
         .def_readwrite("name", &kvc::BaseAgentConfig::mName)
         .def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
         .def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
         .def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
-        .def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
+        .def_readwrite("enable_telemetry", &kvc::BaseAgentConfig::enableTelemetry)
+        .def_readwrite("backend_params", &kvc::BaseAgentConfig::backendParams);
 
     // BaseTransferAgent class (abstract base)
     py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
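The pybind and nanobind variants expose the same surface, so the old `num_workers` field migrates to a string entry in `backend_params`. A sketch of the before/after from Python, using the keyword names defined by the bindings above:

```python
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as tab

# Before this patch: tab.BaseAgentConfig(name="agent0", num_workers=4)
config = tab.BaseAgentConfig(
    name="agent0",
    use_prog_thread=True,
    enable_telemetry=False,
    backend_params={"num_workers": "4"},  # map values are strings
)
```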
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
index 62a8d86e1c..efac8b52b2 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
@@ -374,22 +374,28 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
     }
     auto envPort = common::getEnvNixlPort();
     uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
-    nixlAgentConfig nixlConfig{
-        config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
+    uint32_t numWorker = config.backendParams.find("num_workers") != config.backendParams.end()
+        ? std::stoi(config.backendParams.at("num_workers"))
+        : 1;
+    nixlAgentConfig nixlConfig{config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
+        numWorker, 0, 10000, config.enableTelemetry};
     mAddress = getAvailableIP() + ":" + std::to_string(port);
     mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
 }
 else
 {
+    uint32_t numWorker = config.backendParams.find("num_workers") != config.backendParams.end()
+        ? std::stoi(config.backendParams.at("num_workers"))
+        : 1;
     mAddress.clear();
-    nixlAgentConfig nixlConfig{
-        config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
+    nixlAgentConfig nixlConfig{config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
+        numWorker, 0, 10000, config.enableTelemetry};
     mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
 }
 
 std::string nixlBackend = common::getEnvNixlBackend();
 // List of supported backends - extend this list as new backends are added
-static const std::set<std::string> kSUPPORTED_BACKENDS = {"UCX", "LIBFABRIC"};
+static std::set<std::string> const kSUPPORTED_BACKENDS = {"UCX", "LIBFABRIC"};
 
 if (kSUPPORTED_BACKENDS.find(nixlBackend) == kSUPPORTED_BACKENDS.end())
 {
@@ -400,6 +406,11 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
     TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent using NIXL backend: %s", nixlBackend.c_str());
 
     nixl_b_params_t init1;
+    for (auto const& [key, value] : config.backendParams)
+    {
+        init1[key] = value;
+        TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent backendParams: %s: %s", key.c_str(), value.c_str());
+    }
     nixl_mem_list_t mems1;
     status = mRawAgent->getPluginParams(nixlBackend.c_str(), mems1, init1);
     TLLM_CHECK(status == NIXL_SUCCESS);
diff --git a/cpp/tests/unit_tests/executor/transferAgentTest.cpp b/cpp/tests/unit_tests/executor/transferAgentTest.cpp
index ad31724c73..ffb0cef1e8 100644
--- a/cpp/tests/unit_tests/executor/transferAgentTest.cpp
+++ b/cpp/tests/unit_tests/executor/transferAgentTest.cpp
@@ -92,7 +92,7 @@ TEST_P(TransferAgentTest, Basic)
 {
     std::string const agent0{"agent0"}, agent1{"agent1"};
 
-    BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
+    BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
     auto xferAgent0 = makeTransferAgent(config0);
     auto xferAgent1 = makeTransferAgent(config1);
 
@@ -127,7 +127,7 @@ TEST_P(TransferAgentTest, Basic2)
 {
     std::string const agent0{"agent0"}, agent1{"agent1"};
 
-    BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
+    BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
     auto xferAgent0 = makeTransferAgent(config0);
     auto xferAgent1 = makeTransferAgent(config1);
 
@@ -162,7 +162,7 @@ TEST_P(TransferAgentTest, DeviceMemory)
 {
     std::string const agent0{"agent0"}, agent1{"agent1"};
 
-    BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
+    BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
     auto xferAgent0 = makeTransferAgent(config0);
     auto xferAgent1 = makeTransferAgent(config1);
 
@@ -209,8 +209,8 @@ TEST_P(TransferAgentTest, Connect)
 {
     std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
 
-    BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1},
-        config2{agent2, true, false, true, 1};
+    BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true},
+        config2{agent2, true, false, true};
     auto xferAgent0 = makeTransferAgent(config0);
     auto xferAgent1 = makeTransferAgent(config1);
     auto xferAgent2 = makeTransferAgent(config2);
@@ -261,7 +261,7 @@ TEST_P(TransferAgentTest, SyncMessage)
 {
     constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<std::size_t>::max();
     std::string const agent0{"agent0"}, agent1{"agent1"};
-    BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
+    BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
     auto xferAgent0 = makeTransferAgent(config0);
     auto xferAgent1 = makeTransferAgent(config1);
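The SyncMessage test exercises the C++ notification path; a rough Python counterpart through the bindings-backed agent added below might look like this. The agent setup and the return shape of `get_notified_sync_messages` are assumptions for illustration, not taken from this patch.

```python
import time

# `agent_a` and `agent_b` are two local transfer agents that have already
# exchanged and loaded each other's descriptors.
agent_a.notify_sync_message("agent_b", "kv_ready")

# Poll the receiving side until the notification shows up.
deadline = time.time() + 5.0
while time.time() < deadline:
    messages = agent_b.get_notified_sync_messages()  # assumed mapping: sender -> messages
    if messages:
        break
    time.sleep(0.01)
```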
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4ff771f955..23de534177 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -38,3 +38,4 @@ opentelemetry-semantic-conventions-ai>=0.4.1
 fuzzywuzzy==0.18.0
 aiperf==0.3.0
 nanobind>=2.9.0
+nixl==0.8.0
diff --git a/tensorrt_llm/_torch/disaggregation/__init__.py b/tensorrt_llm/_torch/disaggregation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tensorrt_llm/_torch/disaggregation/base/__init__.py b/tensorrt_llm/_torch/disaggregation/base/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tensorrt_llm/_torch/disaggregation/base/agent.py b/tensorrt_llm/_torch/disaggregation/base/agent.py
new file mode 100644
index 0000000000..1ac8647377
--- /dev/null
+++ b/tensorrt_llm/_torch/disaggregation/base/agent.py
@@ -0,0 +1,145 @@
+import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, NamedTuple, Optional
+
+from tensorrt_llm import logger
+
+
+# We deliberately use plain classes rather than enums here, so that the
+# members are ordinary strings and compare equal to string literals.
+class TransferOp:
+    READ = "READ"
+    WRITE = "WRITE"
+
+
+class MemoryType:
+    DRAM = "DRAM"
+    VRAM = "VRAM"
+    BLK = "BLK"
+    OBJ = "OBJ"
+    FILE = "FILE"
+
+
+class MemoryDesc(NamedTuple):
+    ptr: int
+    size: int
+    device_id: int
+    name: Optional[str] = None
+
+
+@dataclass
+class MemoryDescs:
+    type: str
+    descs: List[MemoryDesc]
+
+
+@dataclass
+class TransferRequest:
+    op: TransferOp
+    src_descs: MemoryDescs
+    dst_descs: MemoryDescs
+    remote_name: str
+    sync_message: Optional[str] = None
+
+
+@dataclass
+class RegMemoryDescs:
+    type: str
+    descs: List[MemoryDesc]
+
+
+class TransferStatus(ABC):
+    @abstractmethod
+    def is_completed(self) -> bool: ...
+
+    @abstractmethod
+    def wait(self, timeout_ms: int | None = None) -> bool: ...
+
+
+class BaseTransferAgent(ABC):
+    @abstractmethod
+    def register_memory(self, descs: RegMemoryDescs) -> None: ...
+
+    @abstractmethod
+    def deregister_memory(self, descs: RegMemoryDescs) -> None: ...
+
+    @abstractmethod
+    def load_remote_agent(self, name: str, agent_desc: bytes) -> None: ...
+
+    @abstractmethod
+    def get_local_agent_desc(self) -> bytes: ...
+
+    @abstractmethod
+    def invalidate_remote_agent(self, name: str) -> None: ...
+
+    @abstractmethod
+    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ...
+
+    @abstractmethod
+    def notify_sync_message(self, name: str, sync_message: str) -> None: ...
+
+    @abstractmethod
+    def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: ...
+ ) + return False + if env_value == "1": + logger.info("Forcing use of pure Python NIXL KV Transfer Agent implementation.") + return True + return False + + +def _try_load_cpp_binding(): + try: + import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _cpp_binding + + required_attributes = [ + "MemoryType", + "TransferOp", + "MemoryDesc", + "MemoryDescs", + "TransferRequest", + "TransferStatus", + "BaseTransferAgent", + ] + if all(hasattr(_cpp_binding, attr) for attr in required_attributes): + return _cpp_binding + except ImportError: + logger.info("tensorrt_llm_transfer_agent_binding module not found.") + return None + + +_use_pure_python_transfer_agent = None + +# The current implementation still implicitly depends on cpp_bindings. +# We should remove this dependency. +_cpp_binding = _try_load_cpp_binding() + +if _force_py_nixl_kv_transfer(): + logger.info("Using pure Python transfer agent (forced by TRTLLM_USE_PY_NIXL_KVCACHE)") + _use_pure_python_transfer_agent = True +else: + if _cpp_binding: + MemoryType = _cpp_binding.MemoryType + TransferOp = _cpp_binding.TransferOp + MemoryDesc = _cpp_binding.MemoryDesc + MemoryDescs = _cpp_binding.MemoryDescs + TransferRequest = _cpp_binding.TransferRequest + TransferStatus = _cpp_binding.TransferStatus + BaseTransferAgent = _cpp_binding.BaseTransferAgent + logger.info("Using C++ transfer agent binding") + _use_pure_python_transfer_agent = False + else: + logger.info("C++ transfer agent binding unavailable, using pure Python implementation") + _use_pure_python_transfer_agent = True + + +def use_pure_python_transfer_agent() -> bool: + return _use_pure_python_transfer_agent diff --git a/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py new file mode 100644 index 0000000000..9cb349f6d2 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + +from tensorrt_llm import DisaggregatedParams + + +@dataclass +class TokenRange: + """Range of tokens in the sequence dimension.""" + + start: int + end: int # exclusive + + def __post_init__(self): + if self.start < 0 or self.end < 0: + raise ValueError("Token indices must be non-negative") + if self.start >= self.end: + raise ValueError(f"Invalid range: [{self.start}, {self.end})") + + +@dataclass +class LayerRange: + """Range of layers to transfer.""" + + start: int + end: int # exclusive + + def __post_init__(self): + if self.start < 0 or self.end < 0: + raise ValueError("Layer indices must be non-negative") + if self.start >= self.end: + raise ValueError(f"Invalid range: [{self.start}, {self.end})") + + +@dataclass +class KVSlice: + """ + Specifies which portion of KV cache to transfer. + """ + + token_range: Optional[TokenRange] = None + layer_range: Optional[LayerRange] = None + block_ids: List[int] = field(default_factory=list) # Physical block IDs + is_last_slice: bool = False + + +class SessionStatus(Enum): + """Status of a transfer session. + + Represents the various stages/statuses that a file transfer session can go through: + + - INIT: The session has been initialized but not yet ready. + - READY: The session is ready to start transferring. + - TRANSFERRING: The session is in progress, currently transferring data. + - TRANSFERRED: The primary transfer has completed successfully. 
diff --git a/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
new file mode 100644
index 0000000000..9cb349f6d2
--- /dev/null
+++ b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List, Optional
+
+from tensorrt_llm import DisaggregatedParams
+
+
+@dataclass
+class TokenRange:
+    """Range of tokens in the sequence dimension."""
+
+    start: int
+    end: int  # exclusive
+
+    def __post_init__(self):
+        if self.start < 0 or self.end < 0:
+            raise ValueError("Token indices must be non-negative")
+        if self.start >= self.end:
+            raise ValueError(f"Invalid range: [{self.start}, {self.end})")
+
+
+@dataclass
+class LayerRange:
+    """Range of layers to transfer."""
+
+    start: int
+    end: int  # exclusive
+
+    def __post_init__(self):
+        if self.start < 0 or self.end < 0:
+            raise ValueError("Layer indices must be non-negative")
+        if self.start >= self.end:
+            raise ValueError(f"Invalid range: [{self.start}, {self.end})")
+
+
+@dataclass
+class KVSlice:
+    """
+    Specifies which portion of KV cache to transfer.
+    """
+
+    token_range: Optional[TokenRange] = None
+    layer_range: Optional[LayerRange] = None
+    block_ids: List[int] = field(default_factory=list)  # Physical block IDs
+    is_last_slice: bool = False
+
+
+class SessionStatus(Enum):
+    """Status of a transfer session.
+
+    Represents the stages a KV-cache transfer session can go through:
+
+    - INIT: The session has been initialized but is not yet ready.
+    - READY: The session is ready to start transferring.
+    - TRANSFERRING: The session is in progress, currently transferring data.
+    - TRANSFERRED: The primary transfer has completed successfully.
+    - AUX_TRANSFERRED: The auxiliary part of the transfer (such as tokens) has completed successfully.
+    - COMPLETED: The entire session, including all transfers, has completed successfully.
+    - CANCELED: The session has been canceled by the user or system.
+    - ERROR: An error occurred during the session. The session could not complete successfully.
+    """
+
+    INIT = "INIT"
+    READY = "READY"
+    TRANSFERRING = "TRANSFERRING"
+    TRANSFERRED = "TRANSFERRED"
+    AUX_TRANSFERRED = "AUX_TRANSFERRED"
+    COMPLETED = "COMPLETED"
+    CANCELED = "CANCELED"
+    ERROR = "ERROR"
+
+
+TaskIdType = int
+
+
+@dataclass
+class SessionState:
+    """State of a transfer session."""
+
+    status: SessionStatus
+    finished_tasks: List[TaskIdType]
+
+
+@dataclass
+class SessionArgsBase:
+    """Base arguments for transfer sessions."""
+
+    params: DisaggregatedParams
+
+
+class SenderBase(ABC):
+    """Base class for sending KV cache data."""
+
+    ...
+
+
+class ReceiverBase(ABC):
+    """Base class for receiving KV cache data."""
+
+    ...
+
+
+class TxSessionBase(ABC):
+    def __init__(self, sender: SenderBase, args: SessionArgsBase):
+        """
+        Initializes the transmission session.
+        :param sender: The sender instance responsible for sending data.
+        :param args: The session arguments.
+        """
+        self._sender = sender
+        self._base_args = args
+
+    @property
+    @abstractmethod
+    def state(self) -> SessionState:
+        """
+        Returns the current state of the session.
+        """
+        ...
+
+    @abstractmethod
+    def poll_task(self, task_id: TaskIdType) -> SessionStatus:
+        """
+        Polls the status of a specific task by its ID.
+        :param task_id: The task ID to poll.
+        """
+        ...
+
+    @abstractmethod
+    def send(self, slice: KVSlice) -> TaskIdType:
+        """
+        Sends a slice of KV cache data and returns the task ID.
+        :param slice: The KV slice to send.
+        """
+        ...
+
+    @property
+    @abstractmethod
+    def exception(self) -> Optional[Exception]:
+        """
+        Returns any exception that occurred during the session.
+        """
+        ...
+
+    @abstractmethod
+    def close(self) -> None:
+        """
+        Closes the session and releases any resources.
+        """
+        ...
+
+
+class RxSessionBase(ABC):
+    def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
+        """
+        Initializes the reception session.
+        :param receiver: The receiver instance responsible for receiving data.
+        :param args: The session arguments.
+        """
+        self._receiver = receiver
+        self._base_args = args
+
+    @property
+    @abstractmethod
+    def state(self) -> SessionState:
+        """
+        Returns the current state of the session.
+        """
+        ...
+
+    @abstractmethod
+    def poll_task(self, task_id: TaskIdType) -> SessionStatus:
+        """
+        Polls the status of a specific task by its ID.
+        :param task_id: The task ID to poll.
+        """
+        ...
+
+    @abstractmethod
+    def receive(self, slice: KVSlice) -> TaskIdType:
+        """
+        Receives a slice of KV cache data and returns the task ID.
+        :param slice: The KV slice to receive.
+        """
+        ...
+
+    @property
+    @abstractmethod
+    def exception(self) -> Optional[Exception]:
+        """Returns any exception that occurred during the session."""
+        ...
+
+    @abstractmethod
+    def close(self) -> None:
+        """
+        Closes the session and releases any resources.
+        """
+        ...
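No concrete session ships in this step, so the following driver loop is purely illustrative of the intended call pattern; `session` stands for some future `TxSessionBase` implementation.

```python
from tensorrt_llm._torch.disaggregation.base.kv_transfer import (
    KVSlice,
    LayerRange,
    SessionStatus,
    TokenRange,
)

# `session` is some concrete TxSessionBase implementation (hypothetical; this
# patch only defines the interfaces).
kv_slice = KVSlice(
    token_range=TokenRange(0, 128),  # first 128 tokens, end-exclusive
    layer_range=LayerRange(0, 32),   # layers [0, 32), end-exclusive
    block_ids=[0, 1, 2, 3],          # physical block IDs backing the range
    is_last_slice=True,
)
task_id = session.send(kv_slice)
while session.poll_task(task_id) not in (SessionStatus.TRANSFERRED, SessionStatus.ERROR):
    pass  # a real caller would yield to its scheduler here
if session.exception is not None:
    raise session.exception
session.close()
```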
diff --git a/tensorrt_llm/_torch/disaggregation/native/__init__.py b/tensorrt_llm/_torch/disaggregation/native/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/disaggregation/native/messenger.py b/tensorrt_llm/_torch/disaggregation/native/messenger.py new file mode 100644 index 0000000000..c10c623492 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/messenger.py @@ -0,0 +1,219 @@ +from abc import ABC, abstractmethod +from threading import Event, Lock, Thread +from typing import Callable, Optional + +import zmq + +from tensorrt_llm import logger +from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip + + +class MessengerInterface(ABC): + """ + Abstract base class for messenger implementations. + """ + + @abstractmethod + def start(self) -> None: + """ + Start the messenger service. + """ + ... + + @abstractmethod + def send(self, messages: list[bytes], recipient: Optional[bytes] = None) -> None: + """ + Send messages to a recipient. + :param messages: List of byte messages to send. + :param recipient: Optional recipient identifier. + """ + ... + + @abstractmethod + def receive(self) -> list[bytes]: + """ + Receive messages. + :return: List of byte messages received. + """ + ... + + @abstractmethod + def start_listener(self, on_message: Callable[[list[bytes]], Optional[bool]]) -> None: + """ + Start a listener thread to handle incoming messages. + :param on_message: Callback function to process received messages. + """ + ... + + @abstractmethod + def stop(self) -> None: + """ + Stop the messenger service. + """ + ... + + @property + @abstractmethod + def endpoint(self) -> str: + """ + Get the endpoint of the messenger. + :return: Endpoint string. + """ + ... + + +def decode_message( + message: list[bytes], encoding: str = "ascii", err_mode: str = "strict" +) -> tuple: + if not isinstance(message, list) or not all(isinstance(m, bytes) for m in message): + raise ValueError("Input must be a list of bytes") + return tuple(m.decode(encoding, errors=err_mode) for m in message) + + +class ZMQMessenger(MessengerInterface): + SOCKET_MODES = { + "ROUTER": zmq.ROUTER, # Handles multiple connections and routes messages by address. + "DEALER": zmq.DEALER, # Load balances outgoing messages and receives replies fairly. + "REQ": zmq.REQ, # Sends requests and waits for replies (synchronous). + "REP": zmq.REP, # Receives requests and sends replies (synchronous). + } + + def __init__(self, mode: str, endpoint: Optional[str] = None) -> None: + if mode not in self.SOCKET_MODES: + raise ValueError( + f"Invalid mode '{mode}'. 
Allowed modes are {list(self.SOCKET_MODES.keys())}"
+            )
+        self._context = zmq.Context()
+        self._mode = mode
+        self._socket = self._context.socket(self.SOCKET_MODES[mode])
+        self._endpoint: Optional[str] = None
+        self._lock = Lock()
+        self._closed = False
+        self._stop_event = Event()
+        self._listener_thread: Optional[Thread] = None
+        self._initialize_control_sockets()
+
+        if endpoint is None:
+            if mode in ["DEALER", "REQ"]:
+                raise ValueError("endpoint is required for DEALER/REQ modes")
+            endpoint = f"tcp://{get_local_ip()}:*"
+
+        if mode in ["ROUTER", "REP"]:
+            self._socket.bind(endpoint)
+            self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
+        elif mode in ["DEALER", "REQ"]:
+            self._socket.connect(endpoint)
+            self._endpoint = endpoint
+
+        logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})")
+
+    def _initialize_control_sockets(self) -> None:
+        self._control_socket = self._context.socket(zmq.PAIR)
+        self._internal_socket = self._context.socket(zmq.PAIR)
+        inproc_endpoint = "inproc://stop_listener"
+        self._control_socket.bind(inproc_endpoint)
+        self._internal_socket.connect(inproc_endpoint)
+
+    def start(self) -> None:
+        pass
+
+    def send(self, messages: list[bytes], recipient: Optional[bytes] = None) -> None:
+        if recipient:
+            self._socket.send_multipart([recipient] + messages)
+        else:
+            self._socket.send_multipart(messages)
+
+    def receive(self) -> list[bytes]:
+        return self._socket.recv_multipart()
+
+    def start_listener(
+        self,
+        on_message: Callable[[list[bytes]], Optional[bool]],
+        on_error: Optional[Callable[[Exception], None]] = None,
+    ) -> None:
+        assert self._mode in ["ROUTER", "REP"], (
+            "Listener can only be started in ROUTER or REP modes"
+        )
+        if self._listener_thread and self._listener_thread.is_alive():
+            raise RuntimeError("Listener already running")
+
+        def handle_listener_exceptions(
+            exception: Exception, on_error: Optional[Callable[[Exception], None]]
+        ) -> None:
+            logger.error(f"Error in listener: {exception}")
+            if on_error:
+                on_error(exception)
+            else:
+                self._stop_event.set()
+
+        def listener() -> None:
+            poller = zmq.Poller()
+            poller.register(self._socket, zmq.POLLIN)
+            poller.register(self._control_socket, zmq.POLLIN)
+
+            while not self._stop_event.is_set():
+                events = dict(poller.poll(timeout=100))
+                try:
+                    if self._control_socket in events:
+                        self._stop_event.set()
+                    elif self._socket in events:
+                        messages = self.receive()
+                        persist = on_message(messages)
+                        if persist is False:
+                            self._stop_event.set()
+                except zmq.ZMQError as e:
+                    handle_listener_exceptions(e, on_error)
+                    break
+                except Exception as e:
+                    handle_listener_exceptions(e, on_error)
+                    break
+
+            self._stop_event.set()
+
+        self._listener_thread = Thread(target=listener, daemon=True)
+        self._listener_thread.start()
+
+    def stop(self, timeout: int = 5) -> None:
+        def _close_socket(socket: zmq.Socket) -> None:
+            try:
+                if not socket.closed:
+                    socket.close()
+            except Exception as e:
+                logger.error(f"Error closing socket: {e}")
+
+        with self._lock:
+            if self._closed:
+                return
+            self._closed = True
+            logger.debug("Stopping ZMQMessenger...")
+
+            self._stop_event.set()
+            if self._listener_thread:
+                # A single STOP via the inproc PAIR socket wakes the poller.
+                self._internal_socket.send(b"STOP")
+                self._listener_thread.join(timeout)
+                if self._listener_thread.is_alive():
+                    logger.warning("Listener thread did not terminate within timeout")
+
+            _close_socket(self._socket)
+            _close_socket(self._internal_socket)
+            _close_socket(self._control_socket)
+
+            try:
+                if not self._context.closed:
+                    self._context.term()
+            except Exception as e:
+                logger.error(f"Error terminating ZMQ context: {e}")
+
+    @property
+    def endpoint(self) -> str:
+        assert self._endpoint is not None
+        return self._endpoint
+
+    def __enter__(self) -> "ZMQMessenger":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        self.stop()
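A compact usage sketch of the class above; note the `Optional[bool]` return contract of `on_message`, where returning `False` stops the listener thread after one message.

```python
import time

from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger

with ZMQMessenger("ROUTER") as router:  # binds to tcp://<local-ip>:* by default

    def on_message(frames: list[bytes]) -> bool:
        # ROUTER sockets prepend the sender identity frame.
        identity, payload = frames[0], frames[1:]
        print(payload)
        return False  # one-shot: stop the listener after this message

    router.start_listener(on_message)

    with ZMQMessenger("DEALER", endpoint=router.endpoint) as dealer:
        dealer.send([b"ping"])
        time.sleep(0.2)  # give the listener thread a moment to run
```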
diff --git a/tensorrt_llm/_torch/disaggregation/native/utils.py b/tensorrt_llm/_torch/disaggregation/native/utils.py
new file mode 100644
index 0000000000..8bc11d91ea
--- /dev/null
+++ b/tensorrt_llm/_torch/disaggregation/native/utils.py
@@ -0,0 +1,41 @@
+import socket
+
+from tensorrt_llm import logger
+
+
+def get_local_ip() -> str:
+    # `socket` is imported at module level: the hostname fallback below needs
+    # it too, so a function-local import could leave it undefined there.
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+            s.connect(("10.255.255.255", 1))
+            ip = s.getsockname()[0]
+            if not ip.startswith("127."):
+                return ip
+    except (OSError, ValueError):
+        logger.error("Failed to get local IP via UDP socket method.")
+
+    try:
+        import netifaces
+
+        for iface in netifaces.interfaces():
+            addrs = netifaces.ifaddresses(iface)
+            if netifaces.AF_INET in addrs:
+                for addr in addrs[netifaces.AF_INET]:
+                    ip = addr.get("addr", "")
+                    if not ip.startswith("127.") and not ip.startswith("169.254"):
+                        return ip
+    except (ImportError, OSError, ValueError) as e:
+        logger.error(f"Failed to get local IP via netifaces: {e}")
+
+    try:
+        hostname = socket.gethostname()
+        ip = socket.gethostbyname(hostname)
+        if not ip.startswith("127."):
+            return ip
+    except (OSError, ValueError):
+        logger.error("Failed to get local IP via hostname resolution.")
+
+    return "127.0.0.1"
diff --git a/tensorrt_llm/_torch/disaggregation/nixl/__init__.py b/tensorrt_llm/_torch/disaggregation/nixl/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
new file mode 100644
index 0000000000..84f85d41b3
--- /dev/null
+++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
@@ -0,0 +1,147 @@
+from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (  # noqa: E402
+    AgentDesc,
+    BaseAgentConfig,
+    MemoryDescs,
+    MemoryType,
+    TransferState,
+)
+from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
+    NixlTransferAgent as CppNixlTransferAgent,
+)
+
+from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
+
+
+class BindingsNixlTransferStatus(TransferStatus):
+    """TransferStatus wrapper using C++ bindings with GIL release."""
+
+    def __init__(self, cpp_status):
+        self._cpp_status = cpp_status
+
+    def is_completed(self) -> bool:
+        """Check if transfer is completed (releases GIL)."""
+        return self._cpp_status.is_completed()
+
+    @nvtx_range("BindingsNixlTransferStatus.wait")
+    def wait(self, timeout_ms=None) -> bool:
+        """Wait for transfer to complete (releases GIL)."""
+        if timeout_ms is None:
+            timeout_ms = -1
+        return self._cpp_status.wait(timeout_ms) == TransferState.SUCCESS
+
+
+class BindingsNixlTransferAgent(BaseTransferAgent):
+    """NixlTransferAgent using C++ bindings with GIL release support.
+
+    This implementation uses the standalone nixl_bindings C++ module which releases
+    the GIL during blocking operations like wait().
+
+    The nixl_bindings module is independent from the main trtllm bindings,
+    so trtllm can function normally even without NIXL.
+ """ + + def __init__( + self, + name: str, + use_prog_thread: bool = True, + num_threads: int = 1, + enable_telemetry: bool = False, + **kwargs, + ): + backend_params = kwargs + for key, value in backend_params.items(): + backend_params[key] = str(value) + backend_params["num_threads"] = str(num_threads) + + config = BaseAgentConfig( + name, + use_prog_thread, + multi_thread=False, + use_listen_thread=False, + enable_telemetry=enable_telemetry, + backend_params=backend_params, + ) + self._cpp_agent = CppNixlTransferAgent(config) + self.name = name + + def register_memory(self, descs: RegMemoryDescs): + """Register memory regions.""" + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.register_memory(cpp_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + """Deregister memory regions.""" + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.deregister_memory(cpp_descs) + + def load_remote_agent(self, name: str, agent_desc: bytes): + """Load a remote agent by its descriptor (bytes).""" + # AgentDesc expects std::string which can hold binary data + desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode() + cpp_desc = AgentDesc(desc_str) + self._cpp_agent.load_remote_agent(name, cpp_desc) + + def load_remote_agent_by_connection(self, name: str, connection_info: str): + """Load a remote agent by connection info.""" + self._cpp_agent.load_remote_agent_by_connection(name, connection_info) + + def get_local_agent_desc(self) -> bytes: + """Get the local agent descriptor as bytes.""" + agent_desc = self._cpp_agent.get_local_agent_desc() + return agent_desc.backend_agent_desc # Returns bytes + + def get_local_connection_info(self) -> str: + """Get the local connection info.""" + return self._cpp_agent.get_local_connection_info() + + def invalidate_remote_agent(self, name: str): + """Invalidate a remote agent.""" + self._cpp_agent.invalidate_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: + """Check if remote descriptors are available. + + memory_descs should be C++ MemoryDescs type. + """ + return self._cpp_agent.check_remote_descs(name, memory_descs) + + def notify_sync_message(self, name: str, sync_message: str): + """Send a sync message to a remote agent.""" + self._cpp_agent.notify_sync_message(name, sync_message) + + def get_notified_sync_messages(self): + """Get notified sync messages.""" + return self._cpp_agent.get_notified_sync_messages() + + @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: + """Submit transfer requests and return status. + + request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding). + """ + cpp_status = self._cpp_agent.submit_transfer_requests(request) + return BindingsNixlTransferStatus(cpp_status) + + def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs": + """Convert Python RegMemoryDescs to C++ MemoryDescs. + + RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name) + Extract first 3 elements for C++ batch constructor. 
+ """ + mem_type = self._convert_memory_type(descs.type) + # Extract (ptr, size, device_id) from 4-tuple, discard name + tuples = [(d[0], d[1], d[2]) for d in descs.descs] + return MemoryDescs(mem_type, tuples) + + def _convert_memory_type(self, py_type: str) -> "MemoryType": + """Convert Python memory type string to C++ MemoryType.""" + type_map = { + "DRAM": MemoryType.DRAM, + "VRAM": MemoryType.VRAM, + "GPU": MemoryType.VRAM, + "BLK": MemoryType.BLK, + "OBJ": MemoryType.OBJ, + "FILE": MemoryType.FILE, + } + return type_map.get(py_type.upper(), MemoryType.VRAM) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py new file mode 100644 index 0000000000..cd01f3024a --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py @@ -0,0 +1,114 @@ +import time +from enum import Enum + +from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle + +from tensorrt_llm._utils import nvtx_range + +# Import base classes for type compatibility +from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus + + +class TransferState(Enum): + PENDING = "PENDING" + PROCESSING = "PROC" + DONE = "DONE" + ERROR = "ERROR" + + +class NixlTransferStatus(TransferStatus): + def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle): + self.agent = agent + self.handle = handle + + def is_completed(self): + status = TransferState(self.agent.check_xfer_state(self.handle)) + return status == TransferState.DONE + + def wait(self, timeout_ms=None): + start_time = time.time() + status = TransferState.PENDING + sleep_time = 0.0001 # 0.1ms in seconds + max_sleep_time = 0.01 # 10ms in seconds + + timeout = timeout_ms / 1000 if timeout_ms is not None else None + + while status in (TransferState.PENDING, TransferState.PROCESSING): + status = TransferState(self.agent.check_xfer_state(self.handle)) + if status == TransferState.ERROR: + return False # Transfer failed + if timeout is not None and (time.time() - start_time > timeout): + return False # Timeout + time.sleep(sleep_time) + sleep_time = min(sleep_time * 2, max_sleep_time) + return status == TransferState.DONE + + +class NixlTransferAgent(BaseTransferAgent): + """NixlTransferAgent using Python nixl library.""" + + def __init__(self, name: str, use_prog_thread: bool = True, num_threads: int = 1, **kwargs): + """ + Initialize NixlTransferAgent. + :param name: Name of the agent. + :param use_prog_thread: Whether to enable the progress thread, if available. + :param num_workers: Specify number of threads for the supported multi-threaded backends. 
+ """ + self.name = name + self.backends = ["UCX"] + agent_config = nixl_agent_config( + enable_prog_thread=use_prog_thread, backends=self.backends, num_threads=num_threads + ) + self.agent = nixl_agent(name, agent_config) + + def register_memory(self, descs: RegMemoryDescs): + if not descs.descs: + raise ValueError("descs.descs must not be empty") + if isinstance(descs.descs[0], tuple): + assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}" + reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) + assert reg_descs is not None, "Failed to get reg_descs" + self.agent.register_memory(reg_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + if not descs.descs: + raise ValueError("descs.descs must not be empty") + if isinstance(descs.descs[0], tuple): + assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}" + reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) + assert reg_descs is not None, "Failed to get reg_descs" + self.agent.deregister_memory(reg_descs) + + def load_remote_agent(self, name: str, agent_desc: bytes): + self.agent.add_remote_agent(agent_desc) + + def get_local_agent_desc(self): + return self.agent.get_agent_metadata() + + def invalidate_remote_agent(self, name: str): + self.agent.remove_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool: + raise NotImplementedError + + def notify_sync_message(self, name: str, sync_message: str): + raise NotImplementedError + + @nvtx_range("NixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: + src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type) + dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type) + assert src_xfer_descs is not None, "Failed to get src_xfer_descs" + assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs" + sync_message = "" if request.sync_message is None else request.sync_message + handle = self.agent.initialize_xfer( + request.op, + src_xfer_descs, + dst_xfer_descs, + request.remote_name, + sync_message, + ) + status = self.agent.transfer(handle) + if status == "ERROR": + raise RuntimeError("NIXL transfer initialization failed.") + return NixlTransferStatus(self.agent, handle) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/agent.py b/tensorrt_llm/_torch/disaggregation/nixl/agent.py new file mode 100644 index 0000000000..5f6f3db154 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/agent.py @@ -0,0 +1,43 @@ +from tensorrt_llm.logger import logger + +from ..base.agent import use_pure_python_transfer_agent + +"""NIXL Transfer Agent implementations. + +This module provides two implementations: +1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support +2. NixlTransferAgent - Uses the Python nixl library directly (fallback) + +The standalone nixl_bindings module is separate from the main trtllm bindings, +so trtllm can still function normally even without NIXL dependencies. +""" + + +def _load_agent(module_name, required_attributes): + try: + module = __import__(module_name, fromlist=required_attributes, level=0) + if all(hasattr(module, attr) for attr in required_attributes): + return module + except ImportError as e: + logger.info("Failed to import module: %s. 
Error: %s", module_name, str(e)) + return None + + +NixlTransferStatus, NixlTransferAgent = None, None + +if use_pure_python_transfer_agent(): + _py_agent = _load_agent( + module_name="tensorrt_llm._torch.disaggregation.nixl._agent_py", + required_attributes=["NixlTransferAgent", "NixlTransferStatus"], + ) + assert _py_agent is not None, "Failed to load pure Python NIXL Transfer Agent." + NixlTransferStatus = _py_agent.NixlTransferStatus + NixlTransferAgent = _py_agent.NixlTransferAgent +else: + _cpp_agent = _load_agent( + module_name="tensorrt_llm._torch.disaggregation.nixl._agent_cpp", + required_attributes=["BindingsNixlTransferAgent", "BindingsNixlTransferStatus"], + ) + assert _cpp_agent is not None, "Failed to load C++ NIXL Transfer Agent bindings." + NixlTransferStatus = _cpp_agent.BindingsNixlTransferStatus + NixlTransferAgent = _cpp_agent.BindingsNixlTransferAgent diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 6521d65766..efc5f92021 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -33,6 +33,8 @@ l0_a10: - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py + - unittest/disaggregated/test_agent_multi_backends.py + - unittest/disaggregated/test_messenger.py - unittest/disaggregated/test_disagg_cluster_manager_worker.py - unittest/disaggregated/test_cluster_storage.py - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 1a4617fd98..9d98ff7910 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -37,6 +37,8 @@ l0_h100: - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py + - unittest/disaggregated/test_agent_multi_backends.py + - unittest/disaggregated/test_messenger.py - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse diff --git a/tests/unittest/bindings/test_transfer_agent_bindings.py b/tests/unittest/bindings/test_transfer_agent_bindings.py index fe77a67483..2191f65dc9 100644 --- a/tests/unittest/bindings/test_transfer_agent_bindings.py +++ b/tests/unittest/bindings/test_transfer_agent_bindings.py @@ -142,21 +142,24 @@ def test_base_agent_config_custom(): use_prog_thread = True multi_thread = False use_listen_thread = True - num_workers = 4 + enable_telemetry = True + backend_params = {"key1": "value1", "key2": "value2"} config = tab.BaseAgentConfig( name=name, use_prog_thread=use_prog_thread, multi_thread=multi_thread, use_listen_thread=use_listen_thread, - num_workers=num_workers, + enable_telemetry=enable_telemetry, + backend_params=backend_params, ) assert config.name == name assert config.use_prog_thread == use_prog_thread assert config.multi_thread == multi_thread assert config.use_listen_thread == use_listen_thread - assert config.num_workers == num_workers + assert config.enable_telemetry == enable_telemetry + assert config.backend_params == backend_params def test_base_agent_config_readwrite(): @@ -175,8 +178,11 @@ def 
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 6521d65766..efc5f92021 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -33,6 +33,8 @@ l0_a10:
       - unittest/disaggregated/test_disagg_utils.py
       - unittest/disaggregated/test_router.py
       - unittest/disaggregated/test_remoteDictionary.py
+      - unittest/disaggregated/test_agent_multi_backends.py
+      - unittest/disaggregated/test_messenger.py
       - unittest/disaggregated/test_disagg_cluster_manager_worker.py
       - unittest/disaggregated/test_cluster_storage.py
       - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 1a4617fd98..9d98ff7910 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -37,6 +37,8 @@ l0_h100:
       - unittest/disaggregated/test_disagg_utils.py
       - unittest/disaggregated/test_router.py
       - unittest/disaggregated/test_remoteDictionary.py
+      - unittest/disaggregated/test_agent_multi_backends.py
+      - unittest/disaggregated/test_messenger.py
       - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
       - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse
       - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse
diff --git a/tests/unittest/bindings/test_transfer_agent_bindings.py b/tests/unittest/bindings/test_transfer_agent_bindings.py
index fe77a67483..2191f65dc9 100644
--- a/tests/unittest/bindings/test_transfer_agent_bindings.py
+++ b/tests/unittest/bindings/test_transfer_agent_bindings.py
@@ -142,21 +142,24 @@ def test_base_agent_config_custom():
     use_prog_thread = True
     multi_thread = False
     use_listen_thread = True
-    num_workers = 4
+    enable_telemetry = True
+    backend_params = {"key1": "value1", "key2": "value2"}
 
     config = tab.BaseAgentConfig(
         name=name,
         use_prog_thread=use_prog_thread,
         multi_thread=multi_thread,
         use_listen_thread=use_listen_thread,
-        num_workers=num_workers,
+        enable_telemetry=enable_telemetry,
+        backend_params=backend_params,
    )
 
     assert config.name == name
     assert config.use_prog_thread == use_prog_thread
     assert config.multi_thread == multi_thread
     assert config.use_listen_thread == use_listen_thread
-    assert config.num_workers == num_workers
+    assert config.enable_telemetry == enable_telemetry
+    assert config.backend_params == backend_params
 
 
 def test_base_agent_config_readwrite():
@@ -175,8 +178,11 @@ def test_base_agent_config_readwrite():
     config.use_listen_thread = True
     assert config.use_listen_thread is True
 
-    config.num_workers = 8
-    assert config.num_workers == 8
+    config.enable_telemetry = True
+    assert config.enable_telemetry is True
+
+    config.backend_params = {"test_key": "test_value"}
+    assert config.backend_params == {"test_key": "test_value"}
 
 
 def test_transfer_request():
diff --git a/tests/unittest/disaggregated/test_agent.py b/tests/unittest/disaggregated/test_agent.py
new file mode 100644
index 0000000000..134c1ea1a4
--- /dev/null
+++ b/tests/unittest/disaggregated/test_agent.py
@@ -0,0 +1,176 @@
+from dataclasses import dataclass, field
+from unittest import TestCase
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from tensorrt_llm import logger
+from tensorrt_llm._torch.disaggregation.base.agent import (
+    MemoryDescs,
+    MemoryType,
+    RegMemoryDescs,
+    TransferOp,
+    TransferRequest,
+    TransferStatus,
+)
+from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent
+
+
+class TestTransferStatus(TestCase):
+    def test_mock_transfer_status(self):
+        mock_transfer_status = Mock(spec=TransferStatus)
+        mock_transfer_status.is_completed.return_value = True
+        self.assertTrue(mock_transfer_status.is_completed())
+        mock_transfer_status.is_completed.assert_called_once()
+        mock_transfer_status.wait.return_value = True
+        timeout_values = [None, 1000, 5000]
+        for timeout in timeout_values:
+            with self.subTest(timeout=timeout):
+                result = mock_transfer_status.wait(timeout_ms=timeout)
+                self.assertTrue(result)
+                mock_transfer_status.wait.assert_called_with(timeout_ms=timeout)
+
+
+def _convert_to_memory_descs(reg_descs: RegMemoryDescs) -> MemoryDescs:
+    tuples = [(ptr, size, device_id) for (ptr, size, device_id, _) in reg_descs.descs]
+
+    def _convert_memory_type(py_type: str) -> MemoryType:
+        """Convert Python memory type string to C++ MemoryType."""
+        type_map = {
+            "DRAM": MemoryType.DRAM,
+            "VRAM": MemoryType.VRAM,
+            "GPU": MemoryType.VRAM,
+            "BLK": MemoryType.BLK,
+            "OBJ": MemoryType.OBJ,
+            "FILE": MemoryType.FILE,
+        }
+        return type_map.get(py_type.upper(), MemoryType.VRAM)
+
+    return MemoryDescs(_convert_memory_type(reg_descs.type), tuples)
+
+
+@dataclass
+class MemoryManager:
+    allocated_memory: list[torch.Tensor] = field(default_factory=list)
+
+    def allocate_memory(
+        self, size: int, name: str, memory_type: str = "VRAM", device_id: int = 0
+    ) -> RegMemoryDescs:
+        # Compare against the plain string: under the C++ bindings,
+        # MemoryType.VRAM is an enum and never equals the "VRAM" string
+        # that the fixtures pass in.
+        device = torch.device(f"cuda:{device_id}" if memory_type == "VRAM" else "cpu")
+
+        # Allocate memory block using torch.Tensor and track it
+        block = torch.zeros(size, dtype=torch.uint8, device=device)
+        self.allocated_memory.append(block)
+
+        # Return RegMemoryDescs built with keyword arguments
+        memory_descs = RegMemoryDescs(
+            type=memory_type, descs=[(block.data_ptr(), block.numel(), device_id, name)]
+        )
+        return memory_descs
+
+    def clear_memory(self):
+        """Clear all tracked memory blocks."""
+        self.allocated_memory.clear()
+
+
+@pytest.fixture
+def memory_manager():
+    return MemoryManager()
+
+
+@pytest.fixture(params=[256, 512])
+def memory_size(request):
+    return request.param
+
+
+@pytest.fixture(params=["DRAM", "VRAM"])
+def memory_type(request):
+    return request.param
+
+
+@pytest.fixture
+def alloc(memory_manager, memory_size, memory_type):
+    """Allocate memory for source and destination, based on the memory_size and memory_type parameters."""
+    assert memory_size > 0, "Memory size must be a positive integer."
+ if memory_type == "VRAM" and not torch.cuda.is_available(): + pytest.skip("CUDA not available for VRAM transfer tests") + src_descs = memory_manager.allocate_memory( + size=memory_size, name="src_mem", memory_type=memory_type + ) + dst_descs = memory_manager.allocate_memory( + size=memory_size, name="dst_mem", memory_type=memory_type + ) + return src_descs, dst_descs + + +@pytest.fixture +def transfer_agent_src(): + return NixlTransferAgent(name="src_agent") + + +@pytest.fixture +def transfer_agent_dst(): + return NixlTransferAgent(name="dst_agent") + + +def test_transfer_between_agents( + transfer_agent_src, + transfer_agent_dst, + memory_manager, + alloc, + memory_size, + memory_type, +): + """End-to-end test of data transfer between two agents with parameterized memory sizes and types.""" + # Debug log the parameters being tested + logger.info(f"Testing with memory_size={memory_size}, memory_type={memory_type}") + + # Unpack source and destination memory descriptions + memory_descs_src, memory_descs_dst = alloc + + # Fill source memory with sequential data for validation + src_data = memory_manager.allocated_memory[0] + assert memory_size > 0, "Memory size must be positive." + tensor = torch.arange(memory_size, dtype=torch.uint8) % 10 + src_data.copy_(tensor) + + # Register memory with source and destination agents + transfer_agent_src.register_memory(memory_descs_src) + transfer_agent_dst.register_memory(memory_descs_dst) + + src_agent_desc = transfer_agent_src.get_local_agent_desc() + transfer_agent_dst.load_remote_agent("src_agent", src_agent_desc) + + dst_agent_desc = transfer_agent_dst.get_local_agent_desc() + transfer_agent_src.load_remote_agent("dst_agent", dst_agent_desc) + + # Create and submit the transfer request + transfer_request = TransferRequest( + op=TransferOp.WRITE, + src_descs=_convert_to_memory_descs(memory_descs_src), + dst_descs=_convert_to_memory_descs(memory_descs_dst), + remote_name="dst_agent", + sync_message=None, + ) + transfer_status = transfer_agent_src.submit_transfer_requests(transfer_request) + assert transfer_status.wait(timeout_ms=5000), "Transfer did not complete within timeout." + + # Validate transfer completion + assert transfer_status.is_completed(), "Transfer did not complete successfully." + + # Validate that the destination data matches the source data + dst_data = memory_manager.allocated_memory[1] + assert torch.equal(dst_data, src_data), "Destination data does not match source data." 
+ + # Clean up by deregistering memory and clearing allocations + transfer_agent_src.deregister_memory(memory_descs_src) + transfer_agent_dst.deregister_memory(memory_descs_dst) + memory_manager.clear_memory() + + transfer_agent_src.invalidate_remote_agent("dst_agent") + transfer_agent_dst.invalidate_remote_agent("src_agent") + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/unittest/disaggregated/test_agent_multi_backends.py b/tests/unittest/disaggregated/test_agent_multi_backends.py new file mode 100644 index 0000000000..0a95bad03b --- /dev/null +++ b/tests/unittest/disaggregated/test_agent_multi_backends.py @@ -0,0 +1,32 @@ +import os +import subprocess + +import pytest + + +@pytest.mark.parametrize("use_py_nixl", ["0", "1"]) +def test_run_with_different_env(use_py_nixl): + os.environ["TRTLLM_USE_PY_NIXL_KVCACHE"] = use_py_nixl + print(f"Running tests with TRTLLM_USE_PY_NIXL_KVCACHE={use_py_nixl}") + + test_file_path = os.path.join(os.path.dirname(__file__), "test_agent.py") + print(f"Running tests in: {test_file_path}") + + result = subprocess.run( + ["pytest", "--capture=no", test_file_path], + env=os.environ.copy(), + capture_output=True, + text=True, + ) + + print(result.stdout) + + if result.returncode != 0: + print("Test failed. stderr output:") + print(result.stderr) + + assert result.returncode == 0, f"Tests failed with TRTLLM_USE_PY_NIXL_KVCACHE={use_py_nixl}" + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/unittest/disaggregated/test_messenger.py b/tests/unittest/disaggregated/test_messenger.py new file mode 100644 index 0000000000..94e052d531 --- /dev/null +++ b/tests/unittest/disaggregated/test_messenger.py @@ -0,0 +1,127 @@ +import socket +import time +import unittest + +import pytest +from parameterized import parameterized + +from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message +from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip + +TEST_CASES = [ + { + "name": "valid_message", + "message": [b"hello", b"world"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": ("hello", "world"), + "raises": None, + }, + { + "name": "invalid_input", + "message": ["hello", b"world"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": None, + "raises": ValueError, + }, + { + "name": "decoding_error", + "message": [b"\xff"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": None, + "raises": UnicodeDecodeError, + }, + { + "name": "decoding_with_ignore", + "message": [b"\xff"], + "encoding": "utf-8", + "err_mode": "ignore", + "expected": ("",), + "raises": None, + }, +] + + +class TestDecodeMessage(unittest.TestCase): + @parameterized.expand([(case["name"], case) for case in TEST_CASES]) + def test_decode_message(self, name, case): + message = case["message"] + encoding = case["encoding"] + err_mode = case["err_mode"] + expected = case["expected"] + raises = case["raises"] + + if raises: + with self.assertRaises(raises): + decode_message(message, encoding=encoding, err_mode=err_mode) + else: + decoded = decode_message(message, encoding=encoding, err_mode=err_mode) + self.assertEqual(decoded, expected) + + +@pytest.fixture +def dynamic_endpoint(): + """Fixture to dynamically generate an available endpoint with a free port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to an available port provided by the OS + port = s.getsockname()[1] + return f"tcp://{get_local_ip()}:{port}" + + +@pytest.fixture +def 
create_messenger_pair(dynamic_endpoint): + def _create_messenger_pair(mode1, mode2): + messenger1 = ZMQMessenger( + mode1, endpoint=dynamic_endpoint if mode1 in ["ROUTER", "REP"] else None + ) + messenger2 = ZMQMessenger( + mode2, endpoint=dynamic_endpoint if mode2 in ["DEALER", "REQ"] else None + ) + return messenger1, messenger2 + + yield _create_messenger_pair + + +def test_router_dealer(create_messenger_pair): + """Test ROUTER and DEALER communication.""" + router, dealer = create_messenger_pair("ROUTER", "DEALER") + + received_messages = [] + + def on_message(messages): + received_messages.extend(messages) + + router.start_listener(on_message) + + dealer.send([b"Hello, ROUTER!"]) + + time.sleep(0.1) + + assert len(received_messages) > 0 + assert b"Hello, ROUTER!" in received_messages + + router.stop() + dealer.stop() + + +def test_req_rep(create_messenger_pair): + """Test REQ and REP communication.""" + rep, req = create_messenger_pair("REP", "REQ") + + def on_message(messages): + rep.send(messages) + + rep.start_listener(on_message) + + req.send([b"Hello, REP!"]) + response = req.receive() + assert response == [b"Hello, REP!"] + + req.stop() + rep.stop() + + +if __name__ == "__main__": + unittest.main()
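Taken together, the messenger gives disaggregated workers a control channel over which NIXL agent descriptors can be exchanged. A hypothetical one-shot handshake (not part of this patch) could look like this:

```python
from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger
from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent

agent = NixlTransferAgent(name="ctx_agent")

with ZMQMessenger("REP") as server:
    print(f"hand this endpoint to the peer out of band: {server.endpoint}")

    def on_message(frames: list[bytes]) -> bool:
        peer_name, peer_desc = frames[0].decode(), frames[1]
        server.send([agent.get_local_agent_desc()])  # reply with our descriptor
        agent.load_remote_agent(peer_name, peer_desc)  # and load theirs
        return False  # one-shot handshake

    server.start_listener(on_message)
    # The peer would connect a REQ messenger to server.endpoint, send
    # [its_name, its_descriptor], and load the reply the same way.
```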