mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[TRTLLM-9527][feat] Python transceiver components (step 2) (#10494)
Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
This commit is contained in:
parent
9adef4eb28
commit
944c304bbb
@ -296,7 +296,8 @@ struct BaseAgentConfig
|
||||
bool useProgThread;
|
||||
bool multiThread;
|
||||
bool useListenThread;
|
||||
unsigned int numWorkers;
|
||||
bool enableTelemetry;
|
||||
std::unordered_map<std::string, std::string> backendParams;
|
||||
};
|
||||
|
||||
class BaseTransferAgent
|
||||
|
||||
@ -247,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
|
||||
|
||||
mAgentName = genUniqueAgentName();
|
||||
// Create Agent
|
||||
BaseAgentConfig config{mAgentName, true, false, true, 1};
|
||||
BaseAgentConfig config{mAgentName, true, false, true};
|
||||
m_Agent = makeTransferAgent(backendType, &config);
|
||||
TLLM_CHECK(!mCacheTransBufferManagers.empty());
|
||||
std::vector<MemoryDesc> memDescs;
|
||||
|
||||
@ -26,10 +26,9 @@
|
||||
#endif
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
@ -69,6 +68,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
// MemoryDescs class
|
||||
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
|
||||
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
|
||||
// Batch constructor from list of tuples: [(ptr, size, device_id), ...]
|
||||
.def(
|
||||
"__init__",
|
||||
[](kvc::MemoryDescs* self, kvc::MemoryType type,
|
||||
std::vector<std::tuple<uintptr_t, size_t, uint32_t>> const& tuples)
|
||||
{
|
||||
std::vector<kvc::MemoryDesc> descs;
|
||||
descs.reserve(tuples.size());
|
||||
for (auto const& [addr, len, deviceId] : tuples)
|
||||
{
|
||||
descs.emplace_back(addr, len, deviceId);
|
||||
}
|
||||
new (self) kvc::MemoryDescs(type, std::move(descs));
|
||||
},
|
||||
nb::arg("type"), nb::arg("tuples"))
|
||||
.def_prop_ro("type", &kvc::MemoryDescs::getType)
|
||||
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
|
||||
|
||||
@ -113,17 +127,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
.def(
|
||||
"__init__",
|
||||
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
|
||||
bool use_listen_thread, unsigned int num_workers) {
|
||||
new (self) kvc::BaseAgentConfig{
|
||||
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
|
||||
bool use_listen_thread, bool enable_telemetry,
|
||||
std::unordered_map<std::string, std::string> backend_params)
|
||||
{
|
||||
new (self) kvc::BaseAgentConfig{std::move(name), use_prog_thread, multi_thread, use_listen_thread,
|
||||
enable_telemetry, std::move(backend_params)};
|
||||
},
|
||||
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
|
||||
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
|
||||
nb::arg("use_listen_thread") = false, nb::arg("enable_telemetry") = false,
|
||||
nb::arg("backend_params") = std::unordered_map<std::string, std::string>{})
|
||||
.def_rw("name", &kvc::BaseAgentConfig::mName)
|
||||
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
|
||||
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
|
||||
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
|
||||
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
|
||||
.def_rw("enable_telemetry", &kvc::BaseAgentConfig::enableTelemetry)
|
||||
.def_rw("backend_params", &kvc::BaseAgentConfig::backendParams);
|
||||
|
||||
// BaseTransferAgent class (abstract base)
|
||||
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
|
||||
|
||||
@ -66,6 +66,19 @@ PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
// MemoryDescs class
|
||||
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
|
||||
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
|
||||
// Batch constructor from list of tuples: [(ptr, size, device_id), ...]
|
||||
.def(py::init(
|
||||
[](kvc::MemoryType type, std::vector<std::tuple<uintptr_t, size_t, uint32_t>> const& tuples)
|
||||
{
|
||||
std::vector<kvc::MemoryDesc> descs;
|
||||
descs.reserve(tuples.size());
|
||||
for (auto const& [addr, len, deviceId] : tuples)
|
||||
{
|
||||
descs.emplace_back(addr, len, deviceId);
|
||||
}
|
||||
return kvc::MemoryDescs(type, std::move(descs));
|
||||
}),
|
||||
py::arg("type"), py::arg("tuples"))
|
||||
.def_property_readonly("type", &kvc::MemoryDescs::getType)
|
||||
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
|
||||
|
||||
@ -108,17 +121,20 @@ PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
.def(py::init<>())
|
||||
.def(py::init(
|
||||
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
|
||||
unsigned int num_workers) {
|
||||
return kvc::BaseAgentConfig{
|
||||
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
|
||||
bool enable_telemetry, std::unordered_map<std::string, std::string> backend_params)
|
||||
{
|
||||
return kvc::BaseAgentConfig{std::move(name), use_prog_thread, multi_thread, use_listen_thread,
|
||||
enable_telemetry, std::move(backend_params)};
|
||||
}),
|
||||
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
|
||||
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
|
||||
py::arg("use_listen_thread") = false, py::arg("enable_telemetry") = false,
|
||||
py::arg("backend_params") = std::unordered_map<std::string, std::string>{})
|
||||
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
|
||||
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
|
||||
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
|
||||
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
|
||||
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
|
||||
.def_readwrite("enable_telemetry", &kvc::BaseAgentConfig::enableTelemetry)
|
||||
.def_readwrite("backend_params", &kvc::BaseAgentConfig::backendParams);
|
||||
|
||||
// BaseTransferAgent class (abstract base)
|
||||
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
|
||||
|
||||
@ -374,22 +374,28 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
|
||||
}
|
||||
auto envPort = common::getEnvNixlPort();
|
||||
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
|
||||
nixlAgentConfig nixlConfig{
|
||||
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
|
||||
uint32_t numWorker = config.backendParams.find("num_workers") != config.backendParams.end()
|
||||
? std::stoi(config.backendParams.at("num_workers"))
|
||||
: 1;
|
||||
nixlAgentConfig nixlConfig{config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
|
||||
numWorker, 0, 10000, config.enableTelemetry};
|
||||
mAddress = getAvailableIP() + ":" + std::to_string(port);
|
||||
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t numWorker = config.backendParams.find("num_workers") != config.backendParams.end()
|
||||
? std::stoi(config.backendParams.at("num_workers"))
|
||||
: 1;
|
||||
mAddress.clear();
|
||||
nixlAgentConfig nixlConfig{
|
||||
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
|
||||
nixlAgentConfig nixlConfig{config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
|
||||
numWorker, 0, 10000, config.enableTelemetry};
|
||||
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
|
||||
}
|
||||
|
||||
std::string nixlBackend = common::getEnvNixlBackend();
|
||||
// List of supported backends - extend this list as new backends are added
|
||||
static const std::set<std::string> kSUPPORTED_BACKENDS = {"UCX", "LIBFABRIC"};
|
||||
static std::set<std::string> const kSUPPORTED_BACKENDS = {"UCX", "LIBFABRIC"};
|
||||
|
||||
if (kSUPPORTED_BACKENDS.find(nixlBackend) == kSUPPORTED_BACKENDS.end())
|
||||
{
|
||||
@ -400,6 +406,11 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
|
||||
TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent using NIXL backend: %s", nixlBackend.c_str());
|
||||
|
||||
nixl_b_params_t init1;
|
||||
for (auto const& [key, value] : config.backendParams)
|
||||
{
|
||||
init1[key] = value;
|
||||
TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent backendParams: %s: %s", key.c_str(), value.c_str());
|
||||
}
|
||||
nixl_mem_list_t mems1;
|
||||
status = mRawAgent->getPluginParams(nixlBackend.c_str(), mems1, init1);
|
||||
TLLM_CHECK(status == NIXL_SUCCESS);
|
||||
|
||||
@ -92,7 +92,7 @@ TEST_P(TransferAgentTest, Basic)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -127,7 +127,7 @@ TEST_P(TransferAgentTest, Basic2)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -162,7 +162,7 @@ TEST_P(TransferAgentTest, DeviceMemory)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -209,8 +209,8 @@ TEST_P(TransferAgentTest, Connect)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1},
|
||||
config2{agent2, true, false, true, 1};
|
||||
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true},
|
||||
config2{agent2, true, false, true};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent2 = makeTransferAgent(config2);
|
||||
@ -261,7 +261,7 @@ TEST_P(TransferAgentTest, SyncMessage)
|
||||
{
|
||||
constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<size_t>::max();
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
|
||||
@ -38,3 +38,4 @@ opentelemetry-semantic-conventions-ai>=0.4.1
|
||||
fuzzywuzzy==0.18.0
|
||||
aiperf==0.3.0
|
||||
nanobind>=2.9.0
|
||||
nixl==0.8.0
|
||||
|
||||
0
tensorrt_llm/_torch/disaggregation/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/base/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/base/__init__.py
Normal file
145
tensorrt_llm/_torch/disaggregation/base/agent.py
Normal file
145
tensorrt_llm/_torch/disaggregation/base/agent.py
Normal file
@ -0,0 +1,145 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
from tensorrt_llm import logger
|
||||
|
||||
|
||||
# We deliberately use a non-enum data structure here. This choice ensures that
|
||||
# members are directly equivalent to the plain strings.
|
||||
class TransferOp:
|
||||
READ = "READ"
|
||||
WRITE = "WRITE"
|
||||
|
||||
|
||||
class MemoryType:
|
||||
DRAM = "DRAM"
|
||||
VRAM = "VRAM"
|
||||
BLK = "BLK"
|
||||
OBJ = "OBJ"
|
||||
FILE = "FILE"
|
||||
|
||||
|
||||
class MemoryDesc(NamedTuple):
|
||||
ptr: int
|
||||
size: int
|
||||
device_id: int
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryDescs:
|
||||
type: str
|
||||
descs: List[MemoryDesc]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferRequest:
|
||||
op: TransferOp
|
||||
src_descs: MemoryDescs
|
||||
dst_descs: MemoryDescs
|
||||
remote_name: str
|
||||
sync_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegMemoryDescs:
|
||||
type: str
|
||||
descs: List[MemoryDesc]
|
||||
|
||||
|
||||
class TransferStatus(ABC):
|
||||
@abstractmethod
|
||||
def is_completed(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, timeout_ms: int | None = None) -> bool: ...
|
||||
|
||||
|
||||
class BaseTransferAgent(ABC):
|
||||
@abstractmethod
|
||||
def register_memory(self, descs: RegMemoryDescs) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def deregister_memory(self, descs: RegMemoryDescs) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def load_remote_agent(self, name: str, agent_desc: bytes) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_local_agent_desc(self) -> bytes: ...
|
||||
|
||||
@abstractmethod
|
||||
def invalidate_remote_agent(self, name: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ...
|
||||
|
||||
@abstractmethod
|
||||
def notify_sync_message(self, name: str, sync_message: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: ...
|
||||
|
||||
|
||||
def _force_py_nixl_kv_transfer() -> bool:
|
||||
env_value = os.getenv("TRTLLM_USE_PY_NIXL_KVCACHE", "0")
|
||||
if env_value not in {"0", "1"}:
|
||||
logger.warning(
|
||||
f"Invalid value for TRTLLM_USE_PY_NIXL_KVCACHE: {env_value}. Expected '0' or '1'. Defaulting to '0'."
|
||||
)
|
||||
return False
|
||||
if env_value == "1":
|
||||
logger.info("Forcing use of pure Python NIXL KV Transfer Agent implementation.")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _try_load_cpp_binding():
|
||||
try:
|
||||
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _cpp_binding
|
||||
|
||||
required_attributes = [
|
||||
"MemoryType",
|
||||
"TransferOp",
|
||||
"MemoryDesc",
|
||||
"MemoryDescs",
|
||||
"TransferRequest",
|
||||
"TransferStatus",
|
||||
"BaseTransferAgent",
|
||||
]
|
||||
if all(hasattr(_cpp_binding, attr) for attr in required_attributes):
|
||||
return _cpp_binding
|
||||
except ImportError:
|
||||
logger.info("tensorrt_llm_transfer_agent_binding module not found.")
|
||||
return None
|
||||
|
||||
|
||||
_use_pure_python_transfer_agent = None
|
||||
|
||||
# The current implementation still implicitly depends on cpp_bindings.
|
||||
# We should remove this dependency.
|
||||
_cpp_binding = _try_load_cpp_binding()
|
||||
|
||||
if _force_py_nixl_kv_transfer():
|
||||
logger.info("Using pure Python transfer agent (forced by TRTLLM_USE_PY_NIXL_KVCACHE)")
|
||||
_use_pure_python_transfer_agent = True
|
||||
else:
|
||||
if _cpp_binding:
|
||||
MemoryType = _cpp_binding.MemoryType
|
||||
TransferOp = _cpp_binding.TransferOp
|
||||
MemoryDesc = _cpp_binding.MemoryDesc
|
||||
MemoryDescs = _cpp_binding.MemoryDescs
|
||||
TransferRequest = _cpp_binding.TransferRequest
|
||||
TransferStatus = _cpp_binding.TransferStatus
|
||||
BaseTransferAgent = _cpp_binding.BaseTransferAgent
|
||||
logger.info("Using C++ transfer agent binding")
|
||||
_use_pure_python_transfer_agent = False
|
||||
else:
|
||||
logger.info("C++ transfer agent binding unavailable, using pure Python implementation")
|
||||
_use_pure_python_transfer_agent = True
|
||||
|
||||
|
||||
def use_pure_python_transfer_agent() -> bool:
|
||||
return _use_pure_python_transfer_agent
|
||||
200
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
Normal file
200
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
Normal file
@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from tensorrt_llm import DisaggregatedParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenRange:
|
||||
"""Range of tokens in the sequence dimension."""
|
||||
|
||||
start: int
|
||||
end: int # exclusive
|
||||
|
||||
def __post_init__(self):
|
||||
if self.start < 0 or self.end < 0:
|
||||
raise ValueError("Token indices must be non-negative")
|
||||
if self.start >= self.end:
|
||||
raise ValueError(f"Invalid range: [{self.start}, {self.end})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerRange:
|
||||
"""Range of layers to transfer."""
|
||||
|
||||
start: int
|
||||
end: int # exclusive
|
||||
|
||||
def __post_init__(self):
|
||||
if self.start < 0 or self.end < 0:
|
||||
raise ValueError("Layer indices must be non-negative")
|
||||
if self.start >= self.end:
|
||||
raise ValueError(f"Invalid range: [{self.start}, {self.end})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVSlice:
|
||||
"""
|
||||
Specifies which portion of KV cache to transfer.
|
||||
"""
|
||||
|
||||
token_range: Optional[TokenRange] = None
|
||||
layer_range: Optional[LayerRange] = None
|
||||
block_ids: List[int] = field(default_factory=list) # Physical block IDs
|
||||
is_last_slice: bool = False
|
||||
|
||||
|
||||
class SessionStatus(Enum):
|
||||
"""Status of a transfer session.
|
||||
|
||||
Represents the various stages/statuses that a file transfer session can go through:
|
||||
|
||||
- INIT: The session has been initialized but not yet ready.
|
||||
- READY: The session is ready to start transferring.
|
||||
- TRANSFERRING: The session is in progress, currently transferring data.
|
||||
- TRANSFERRED: The primary transfer has completed successfully.
|
||||
- AUX_TRANSFERRED: The auxiliary part (such as tokens) of the transfer has completed successfully.
|
||||
- COMPLETED: The entire session process, including all transfers, has been successfully completed.
|
||||
- CANCELED: The session has been canceled by the user or system.
|
||||
- ERROR: An error occurred during the session. The session could not complete successfully.
|
||||
"""
|
||||
|
||||
INIT = "INIT"
|
||||
READY = "READY"
|
||||
TRANSFERRING = "TRANSFERRING"
|
||||
TRANSFERRED = "TRANSFERRED"
|
||||
AUX_TRANSFERRED = "AUX_TRANSFERRED"
|
||||
COMPLETED = "COMPLETED"
|
||||
CANCELED = "CANCELED"
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
TaskIdType = int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""State of a transfer session."""
|
||||
|
||||
status: SessionStatus
|
||||
finished_tasks: List[TaskIdType]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionArgsBase:
|
||||
"""Base arguments for transfer sessions."""
|
||||
|
||||
params: DisaggregatedParams
|
||||
|
||||
|
||||
class SenderBase(ABC):
|
||||
"""Base class for sending KV cache data."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class ReceiverBase(ABC):
|
||||
"""Base class for receiving KV cache data."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class TxSessionBase(ABC):
|
||||
def __init__(self, sender: SenderBase, args: SessionArgsBase):
|
||||
"""
|
||||
Initializes the transmission session.
|
||||
:param sender: The sender instance responsible for sending data.
|
||||
:param args: The session arguments.
|
||||
"""
|
||||
self._sender = sender
|
||||
self._base_args = args
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def state(self) -> SessionState:
|
||||
"""
|
||||
Returns the current state of the session.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def poll_task(self, id: TaskIdType) -> SessionStatus:
|
||||
"""
|
||||
Polls the status of a specific task by its ID.
|
||||
:param id: The task ID to poll.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def send(self, slice: KVSlice) -> TaskIdType:
|
||||
"""
|
||||
Sends a slice of KV cache data and returns the task ID.
|
||||
:param slice: The KV slice to send.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def exception(self) -> Optional[Exception]:
|
||||
"""
|
||||
Returns any exception that occurred during the session.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Closes the session and releases any resources.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class RxSessionBase(ABC):
|
||||
def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
|
||||
"""
|
||||
Initializes the reception session.
|
||||
:param receiver: The receiver instance responsible for receiving data.
|
||||
"""
|
||||
self._receiver = receiver
|
||||
self._base_args = args
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def state(self) -> SessionState:
|
||||
"""
|
||||
Returns the current state of the session.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def poll_task(self, task_id: TaskIdType) -> SessionStatus:
|
||||
"""
|
||||
Polls the status of a specific task by its ID.
|
||||
:param task_id: The task ID to poll.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def receive(self, slice: KVSlice) -> TaskIdType:
|
||||
"""
|
||||
Receives a slice of KV cache data and returns the task ID.
|
||||
:param slice: The KV slice to receive.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def exception(self) -> Optional[Exception]:
|
||||
"""Returns any exception that occurred during the session."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Closes the session and releases any resources.
|
||||
"""
|
||||
...
|
||||
219
tensorrt_llm/_torch/disaggregation/native/messenger.py
Normal file
219
tensorrt_llm/_torch/disaggregation/native/messenger.py
Normal file
@ -0,0 +1,219 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
import zmq
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
|
||||
|
||||
|
||||
class MessengerInterface(ABC):
|
||||
"""
|
||||
Abstract base class for messenger implementations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the messenger service.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def send(self, messages: list[bytes], recipient: Optional[bytes] = None) -> None:
|
||||
"""
|
||||
Send messages to a recipient.
|
||||
:param messages: List of byte messages to send.
|
||||
:param recipient: Optional recipient identifier.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def receive(self) -> list[bytes]:
|
||||
"""
|
||||
Receive messages.
|
||||
:return: List of byte messages received.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def start_listener(self, on_message: Callable[[list[bytes]], Optional[bool]]) -> None:
|
||||
"""
|
||||
Start a listener thread to handle incoming messages.
|
||||
:param on_message: Callback function to process received messages.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop the messenger service.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
"""
|
||||
Get the endpoint of the messenger.
|
||||
:return: Endpoint string.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def decode_message(
|
||||
message: list[bytes], encoding: str = "ascii", err_mode: str = "strict"
|
||||
) -> tuple:
|
||||
if not isinstance(message, list) or not all(isinstance(m, bytes) for m in message):
|
||||
raise ValueError("Input must be a list of bytes")
|
||||
return tuple(m.decode(encoding, errors=err_mode) for m in message)
|
||||
|
||||
|
||||
class ZMQMessenger(MessengerInterface):
|
||||
SOCKET_MODES = {
|
||||
"ROUTER": zmq.ROUTER, # Handles multiple connections and routes messages by address.
|
||||
"DEALER": zmq.DEALER, # Load balances outgoing messages and receives replies fairly.
|
||||
"REQ": zmq.REQ, # Sends requests and waits for replies (synchronous).
|
||||
"REP": zmq.REP, # Receives requests and sends replies (synchronous).
|
||||
}
|
||||
|
||||
def __init__(self, mode: str, endpoint: Optional[str] = None) -> None:
|
||||
if mode not in self.SOCKET_MODES:
|
||||
raise ValueError(
|
||||
f"Invalid mode '{mode}'. Allowed modes are {list(self.SOCKET_MODES.keys())}"
|
||||
)
|
||||
self._context = zmq.Context()
|
||||
self._mode = mode
|
||||
self._socket = self._context.socket(self.SOCKET_MODES[mode])
|
||||
self._endpoint: Optional[str] = None
|
||||
self._lock = Lock()
|
||||
self._closed = False
|
||||
self._stop_event = Event()
|
||||
self._listener_thread: Optional[Thread] = None
|
||||
self._initialize_control_sockets()
|
||||
|
||||
if endpoint is None:
|
||||
if mode in ["DEALER", "REQ"]:
|
||||
raise ValueError("endpoint is required for DEALER/REQ modes")
|
||||
endpoint = f"tcp://{get_local_ip()}:*"
|
||||
|
||||
if mode in ["ROUTER", "REP"]:
|
||||
self._socket.bind(endpoint)
|
||||
self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
|
||||
elif mode in ["DEALER", "REQ"]:
|
||||
self._socket.connect(endpoint)
|
||||
self._endpoint = endpoint
|
||||
|
||||
logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})")
|
||||
|
||||
def _initialize_control_sockets(self) -> None:
|
||||
self._control_socket = self._context.socket(zmq.PAIR)
|
||||
self._internal_socket = self._context.socket(zmq.PAIR)
|
||||
inproc_endpoint = "inproc://stop_listener"
|
||||
self._control_socket.bind(inproc_endpoint)
|
||||
self._internal_socket.connect(inproc_endpoint)
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def send(self, messages: list[bytes], recipient: Optional[bytes] = None) -> None:
|
||||
if recipient:
|
||||
self._socket.send_multipart([recipient] + messages)
|
||||
else:
|
||||
self._socket.send_multipart(messages)
|
||||
|
||||
def receive(self) -> list[bytes]:
|
||||
return self._socket.recv_multipart()
|
||||
|
||||
def start_listener(
|
||||
self,
|
||||
on_message: Callable[[list[bytes]], Optional[bool]],
|
||||
on_error: Optional[Callable[[Exception], None]] = None,
|
||||
) -> None:
|
||||
assert self._mode in ["ROUTER", "REP"], (
|
||||
"Listener can only be started in ROUTER or REP modes"
|
||||
)
|
||||
if self._listener_thread and self._listener_thread.is_alive():
|
||||
raise RuntimeError("Listener already running")
|
||||
|
||||
def handle_listener_exceptions(
|
||||
exception: Exception, on_error: Optional[Callable[[Exception], None]]
|
||||
) -> None:
|
||||
logger.error(f"Error in listener: {exception}")
|
||||
if on_error:
|
||||
on_error(exception)
|
||||
else:
|
||||
self._stop_event.set()
|
||||
|
||||
def listener() -> None:
|
||||
poller = zmq.Poller()
|
||||
poller.register(self._socket, zmq.POLLIN)
|
||||
poller.register(self._control_socket, zmq.POLLIN)
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
events = dict(poller.poll(timeout=100))
|
||||
try:
|
||||
if self._control_socket in events:
|
||||
self._stop_event.set()
|
||||
elif self._socket in events:
|
||||
messages = self.receive()
|
||||
persist = on_message(messages)
|
||||
if persist is False:
|
||||
self._stop_event.set()
|
||||
except zmq.ZMQError as e:
|
||||
handle_listener_exceptions(e, on_error)
|
||||
break
|
||||
except Exception as e:
|
||||
handle_listener_exceptions(e, on_error)
|
||||
break
|
||||
|
||||
self._stop_event.set()
|
||||
|
||||
self._listener_thread = Thread(target=listener, daemon=True)
|
||||
self._listener_thread.start()
|
||||
|
||||
def stop(self, timeout: int = 5) -> None:
|
||||
def _close_socket(socket: zmq.Socket) -> None:
|
||||
try:
|
||||
if not socket.closed:
|
||||
socket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing socket: {e}")
|
||||
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
logger.debug("Stopping ZMQMessenger...")
|
||||
|
||||
self._stop_event.set()
|
||||
self._internal_socket.send(b"STOP")
|
||||
if self._listener_thread:
|
||||
self._internal_socket.send(b"STOP")
|
||||
self._listener_thread.join(timeout)
|
||||
if self._listener_thread.is_alive():
|
||||
logger.warning("Listener thread did not terminate within timeout")
|
||||
|
||||
_close_socket(self._socket)
|
||||
_close_socket(self._internal_socket)
|
||||
_close_socket(self._control_socket)
|
||||
|
||||
try:
|
||||
if not self._context.closed:
|
||||
self._context.term()
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating ZMQ context: {e}")
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
assert self._endpoint is not None
|
||||
return self._endpoint
|
||||
|
||||
def __enter__(self) -> "ZMQMessenger":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional
|
||||
) -> None:
|
||||
self.stop()
|
||||
41
tensorrt_llm/_torch/disaggregation/native/utils.py
Normal file
41
tensorrt_llm/_torch/disaggregation/native/utils.py
Normal file
@ -0,0 +1,41 @@
|
||||
from tensorrt_llm import logger
|
||||
|
||||
|
||||
def get_local_ip() -> str:
|
||||
try:
|
||||
import socket
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("10.255.255.255", 1))
|
||||
ip = s.getsockname()[0]
|
||||
if not ip.startswith("127."):
|
||||
return ip
|
||||
|
||||
except (ImportError, OSError, ValueError):
|
||||
logger.error("Failed to get local IP via UDP socket method.")
|
||||
pass
|
||||
|
||||
try:
|
||||
import netifaces
|
||||
|
||||
for iface in netifaces.interfaces():
|
||||
addrs = netifaces.ifaddresses(iface)
|
||||
if netifaces.AF_INET in addrs:
|
||||
for addr in addrs[netifaces.AF_INET]:
|
||||
ip = addr.get("addr", "")
|
||||
if not ip.startswith("127.") and not ip.startswith("169.254"):
|
||||
return ip
|
||||
except (ImportError, OSError, ValueError) as e:
|
||||
logger.error(f"Failed to get local IP via netifaces: {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
ip = socket.gethostbyname(hostname)
|
||||
if not ip.startswith("127."):
|
||||
return ip
|
||||
except (OSError, ValueError):
|
||||
logger.error("Failed to get local IP via hostname resolution.")
|
||||
pass
|
||||
|
||||
return "127.0.0.1"
|
||||
0
tensorrt_llm/_torch/disaggregation/nixl/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/nixl/__init__.py
Normal file
147
tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
Normal file
147
tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
Normal file
@ -0,0 +1,147 @@
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( # noqa: E402
|
||||
AgentDesc,
|
||||
BaseAgentConfig,
|
||||
MemoryDescs,
|
||||
MemoryType,
|
||||
TransferState,
|
||||
)
|
||||
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
|
||||
NixlTransferAgent as CppNixlTransferAgent,
|
||||
)
|
||||
|
||||
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
|
||||
|
||||
|
||||
class BindingsNixlTransferStatus(TransferStatus):
    """Thin adapter exposing a C++ transfer-status object through the
    Python ``TransferStatus`` interface; blocking calls release the GIL
    on the C++ side."""

    def __init__(self, cpp_status):
        # Status object produced by the C++ transfer agent binding.
        self._cpp_status = cpp_status

    def is_completed(self) -> bool:
        """Non-blocking completion probe (releases GIL)."""
        return self._cpp_status.is_completed()

    @nvtx_range("BindingsNixlTransferStatus.wait")
    def wait(self, timeout_ms=None) -> bool:
        """Block until the transfer finishes; ``None`` maps to -1 (wait forever)."""
        effective_timeout = -1 if timeout_ms is None else timeout_ms
        return self._cpp_status.wait(effective_timeout) == TransferState.SUCCESS
class BindingsNixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using C++ bindings with GIL release support.

    This implementation uses the standalone nixl_bindings C++ module which releases
    the GIL during blocking operations like wait().

    The nixl_bindings module is independent from the main trtllm bindings,
    so trtllm can function normally even without NIXL.
    """

    def __init__(
        self,
        name: str,
        use_prog_thread: bool = True,
        num_threads: int = 1,
        enable_telemetry: bool = False,
        **kwargs,
    ):
        # ``**kwargs`` is already a fresh dict, so mutating it in place is
        # safe; the C++ config expects string-valued backend parameters.
        backend_params = kwargs
        for key, value in backend_params.items():
            backend_params[key] = str(value)
        backend_params["num_threads"] = str(num_threads)

        # multi_thread / use_listen_thread are fixed off for this backend.
        config = BaseAgentConfig(
            name,
            use_prog_thread,
            multi_thread=False,
            use_listen_thread=False,
            enable_telemetry=enable_telemetry,
            backend_params=backend_params,
        )
        self._cpp_agent = CppNixlTransferAgent(config)
        self.name = name

    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.register_memory(cpp_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.deregister_memory(cpp_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Load a remote agent by its descriptor (bytes)."""
        # AgentDesc expects std::string which can hold binary data
        desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode()
        cpp_desc = AgentDesc(desc_str)
        self._cpp_agent.load_remote_agent(name, cpp_desc)

    def load_remote_agent_by_connection(self, name: str, connection_info: str):
        """Load a remote agent by connection info."""
        self._cpp_agent.load_remote_agent_by_connection(name, connection_info)

    def get_local_agent_desc(self) -> bytes:
        """Get the local agent descriptor as bytes."""
        agent_desc = self._cpp_agent.get_local_agent_desc()
        return agent_desc.backend_agent_desc  # Returns bytes

    def get_local_connection_info(self) -> str:
        """Get the local connection info."""
        return self._cpp_agent.get_local_connection_info()

    def invalidate_remote_agent(self, name: str):
        """Invalidate a remote agent."""
        self._cpp_agent.invalidate_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool:
        """Check if remote descriptors are available.

        memory_descs should be C++ MemoryDescs type.
        """
        return self._cpp_agent.check_remote_descs(name, memory_descs)

    def notify_sync_message(self, name: str, sync_message: str):
        """Send a sync message to a remote agent."""
        self._cpp_agent.notify_sync_message(name, sync_message)

    def get_notified_sync_messages(self):
        """Get notified sync messages."""
        return self._cpp_agent.get_notified_sync_messages()

    @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Submit transfer requests and return status.

        request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding).
        """
        cpp_status = self._cpp_agent.submit_transfer_requests(request)
        return BindingsNixlTransferStatus(cpp_status)

    def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs":
        """Convert Python RegMemoryDescs to C++ MemoryDescs.

        RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name)
        Extract first 3 elements for C++ batch constructor.
        """
        mem_type = self._convert_memory_type(descs.type)
        # Extract (ptr, size, device_id) from 4-tuple, discard name
        tuples = [(d[0], d[1], d[2]) for d in descs.descs]
        return MemoryDescs(mem_type, tuples)

    def _convert_memory_type(self, py_type: str) -> "MemoryType":
        """Convert Python memory type string to C++ MemoryType."""
        # Unknown strings deliberately fall back to VRAM (device memory),
        # the common case for KV-cache transfers.
        type_map = {
            "DRAM": MemoryType.DRAM,
            "VRAM": MemoryType.VRAM,
            "GPU": MemoryType.VRAM,
            "BLK": MemoryType.BLK,
            "OBJ": MemoryType.OBJ,
            "FILE": MemoryType.FILE,
        }
        return type_map.get(py_type.upper(), MemoryType.VRAM)
114
tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py
Normal file
114
tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py
Normal file
@ -0,0 +1,114 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
|
||||
from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle
|
||||
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
|
||||
# Import base classes for type compatibility
|
||||
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
|
||||
|
||||
|
||||
class TransferState(Enum):
    # Transfer life-cycle states for the pure-Python NIXL backend.
    # The string values mirror the raw state names returned by
    # nixl_agent.check_xfer_state(); do not rename them.
    PENDING = "PENDING"
    PROCESSING = "PROC"
    DONE = "DONE"
    ERROR = "ERROR"
class NixlTransferStatus(TransferStatus):
    """Polls a NIXL transfer handle until it reaches a terminal state."""

    def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle):
        self.agent = agent
        self.handle = handle

    def is_completed(self):
        """Non-blocking check: True iff the transfer reached DONE."""
        status = TransferState(self.agent.check_xfer_state(self.handle))
        return status == TransferState.DONE

    def wait(self, timeout_ms=None):
        """Poll with exponential backoff until DONE/ERROR or timeout.

        :param timeout_ms: Maximum wait in milliseconds; None waits forever.
        :return: True on successful completion, False on error or timeout.
        """
        deadline = None if timeout_ms is None else time.time() + timeout_ms / 1000
        sleep_time = 0.0001  # 0.1ms in seconds
        max_sleep_time = 0.01  # 10ms in seconds

        while True:
            status = TransferState(self.agent.check_xfer_state(self.handle))
            # Decide terminal states immediately: the previous version slept
            # one extra backoff period after DONE, and misreported DONE
            # observed just past the deadline as a timeout failure.
            if status == TransferState.DONE:
                return True
            if status == TransferState.ERROR:
                return False  # Transfer failed
            if deadline is not None and time.time() > deadline:
                return False  # Timeout
            time.sleep(sleep_time)
            sleep_time = min(sleep_time * 2, max_sleep_time)
class NixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using Python nixl library."""

    def __init__(self, name: str, use_prog_thread: bool = True, num_threads: int = 1, **kwargs):
        """
        Initialize NixlTransferAgent.
        :param name: Name of the agent.
        :param use_prog_thread: Whether to enable the progress thread, if available.
        :param num_threads: Specify number of threads for the supported multi-threaded backends.
        """
        self.name = name
        # Only the UCX backend is enabled for the pure-Python agent.
        self.backends = ["UCX"]
        agent_config = nixl_agent_config(
            enable_prog_thread=use_prog_thread, backends=self.backends, num_threads=num_threads
        )
        self.agent = nixl_agent(name, agent_config)

    def register_memory(self, descs: RegMemoryDescs):
        # Register (ptr, size, device_id, name) regions with the NIXL agent.
        if not descs.descs:
            raise ValueError("descs.descs must not be empty")
        if isinstance(descs.descs[0], tuple):
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        self.agent.register_memory(reg_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        # Mirror of register_memory: validates, then deregisters the regions.
        if not descs.descs:
            raise ValueError("descs.descs must not be empty")
        if isinstance(descs.descs[0], tuple):
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        self.agent.deregister_memory(reg_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        # NOTE(review): ``name`` is unused here — presumably nixl derives the
        # peer name from the metadata blob itself; confirm against nixl API.
        self.agent.add_remote_agent(agent_desc)

    def get_local_agent_desc(self):
        # Opaque metadata blob to hand to a peer's load_remote_agent().
        return self.agent.get_agent_metadata()

    def invalidate_remote_agent(self, name: str):
        self.agent.remove_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool:
        # Not supported by the pure-Python backend.
        raise NotImplementedError

    def notify_sync_message(self, name: str, sync_message: str):
        # Not supported by the pure-Python backend.
        raise NotImplementedError

    @nvtx_range("NixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        # Build transfer descriptors for both sides, then start the transfer.
        src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type)
        dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type)
        assert src_xfer_descs is not None, "Failed to get src_xfer_descs"
        assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs"
        sync_message = "" if request.sync_message is None else request.sync_message
        handle = self.agent.initialize_xfer(
            request.op,
            src_xfer_descs,
            dst_xfer_descs,
            request.remote_name,
            sync_message,
        )
        status = self.agent.transfer(handle)
        if status == "ERROR":
            raise RuntimeError("NIXL transfer initialization failed.")
        return NixlTransferStatus(self.agent, handle)
43
tensorrt_llm/_torch/disaggregation/nixl/agent.py
Normal file
43
tensorrt_llm/_torch/disaggregation/nixl/agent.py
Normal file
@ -0,0 +1,43 @@
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from ..base.agent import use_pure_python_transfer_agent
|
||||
|
||||
"""NIXL Transfer Agent implementations.
|
||||
|
||||
This module provides two implementations:
|
||||
1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support
|
||||
2. NixlTransferAgent - Uses the Python nixl library directly (fallback)
|
||||
|
||||
The standalone nixl_bindings module is separate from the main trtllm bindings,
|
||||
so trtllm can still function normally even without NIXL dependencies.
|
||||
"""
|
||||
|
||||
|
||||
def _load_agent(module_name, required_attributes):
|
||||
try:
|
||||
module = __import__(module_name, fromlist=required_attributes, level=0)
|
||||
if all(hasattr(module, attr) for attr in required_attributes):
|
||||
return module
|
||||
except ImportError as e:
|
||||
logger.info("Failed to import module: %s. Error: %s", module_name, str(e))
|
||||
return None
|
||||
|
||||
|
||||
# Public aliases resolved once at import time: exactly one backend is chosen.
NixlTransferStatus, NixlTransferAgent = None, None

if use_pure_python_transfer_agent():
    # Pure-Python path: backed by the `nixl` Python package.
    _py_agent = _load_agent(
        module_name="tensorrt_llm._torch.disaggregation.nixl._agent_py",
        required_attributes=["NixlTransferAgent", "NixlTransferStatus"],
    )
    assert _py_agent is not None, "Failed to load pure Python NIXL Transfer Agent."
    NixlTransferStatus = _py_agent.NixlTransferStatus
    NixlTransferAgent = _py_agent.NixlTransferAgent
else:
    # Default path: C++ bindings that release the GIL during blocking calls.
    _cpp_agent = _load_agent(
        module_name="tensorrt_llm._torch.disaggregation.nixl._agent_cpp",
        required_attributes=["BindingsNixlTransferAgent", "BindingsNixlTransferStatus"],
    )
    assert _cpp_agent is not None, "Failed to load C++ NIXL Transfer Agent bindings."
    NixlTransferStatus = _cpp_agent.BindingsNixlTransferStatus
    NixlTransferAgent = _cpp_agent.BindingsNixlTransferAgent
@ -33,6 +33,8 @@ l0_a10:
|
||||
- unittest/disaggregated/test_disagg_utils.py
|
||||
- unittest/disaggregated/test_router.py
|
||||
- unittest/disaggregated/test_remoteDictionary.py
|
||||
- unittest/disaggregated/test_agent_multi_backends.py
|
||||
- unittest/disaggregated/test_messenger.py
|
||||
- unittest/disaggregated/test_disagg_cluster_manager_worker.py
|
||||
- unittest/disaggregated/test_cluster_storage.py
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
|
||||
|
||||
@ -37,6 +37,8 @@ l0_h100:
|
||||
- unittest/disaggregated/test_disagg_utils.py
|
||||
- unittest/disaggregated/test_router.py
|
||||
- unittest/disaggregated/test_remoteDictionary.py
|
||||
- unittest/disaggregated/test_agent_multi_backends.py
|
||||
- unittest/disaggregated/test_messenger.py
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse
|
||||
|
||||
@ -142,21 +142,24 @@ def test_base_agent_config_custom():
|
||||
use_prog_thread = True
|
||||
multi_thread = False
|
||||
use_listen_thread = True
|
||||
num_workers = 4
|
||||
enable_telemetry = True
|
||||
backend_params = {"key1": "value1", "key2": "value2"}
|
||||
|
||||
config = tab.BaseAgentConfig(
|
||||
name=name,
|
||||
use_prog_thread=use_prog_thread,
|
||||
multi_thread=multi_thread,
|
||||
use_listen_thread=use_listen_thread,
|
||||
num_workers=num_workers,
|
||||
enable_telemetry=enable_telemetry,
|
||||
backend_params=backend_params,
|
||||
)
|
||||
|
||||
assert config.name == name
|
||||
assert config.use_prog_thread == use_prog_thread
|
||||
assert config.multi_thread == multi_thread
|
||||
assert config.use_listen_thread == use_listen_thread
|
||||
assert config.num_workers == num_workers
|
||||
assert config.enable_telemetry == enable_telemetry
|
||||
assert config.backend_params == backend_params
|
||||
|
||||
|
||||
def test_base_agent_config_readwrite():
|
||||
@ -175,8 +178,11 @@ def test_base_agent_config_readwrite():
|
||||
config.use_listen_thread = True
|
||||
assert config.use_listen_thread is True
|
||||
|
||||
config.num_workers = 8
|
||||
assert config.num_workers == 8
|
||||
config.enable_telemetry = True
|
||||
assert config.enable_telemetry is True
|
||||
|
||||
config.backend_params = {"test_key": "test_value"}
|
||||
assert config.backend_params == {"test_key": "test_value"}
|
||||
|
||||
|
||||
def test_transfer_request():
|
||||
|
||||
176
tests/unittest/disaggregated/test_agent.py
Normal file
176
tests/unittest/disaggregated/test_agent.py
Normal file
@ -0,0 +1,176 @@
|
||||
from dataclasses import dataclass, field
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._torch.disaggregation.base.agent import (
|
||||
MemoryDescs,
|
||||
MemoryType,
|
||||
RegMemoryDescs,
|
||||
TransferOp,
|
||||
TransferRequest,
|
||||
TransferStatus,
|
||||
)
|
||||
from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent
|
||||
|
||||
|
||||
class TestTransferStatus(TestCase):
    """Contract checks for the TransferStatus interface via a spec'd mock."""

    def test_mock_transfer_status(self):
        status_mock = Mock(spec=TransferStatus)

        status_mock.is_completed.return_value = True
        self.assertTrue(status_mock.is_completed())
        status_mock.is_completed.assert_called_once()

        status_mock.wait.return_value = True
        for timeout in (None, 1000, 5000):
            with self.subTest(timeout=timeout):
                self.assertTrue(status_mock.wait(timeout_ms=timeout))
                status_mock.wait.assert_called_with(timeout_ms=timeout)
def _convert_to_memory_descs(reg_descs: RegMemoryDescs) -> MemoryDescs:
    """Build a C++ MemoryDescs from a Python RegMemoryDescs, dropping the
    per-region name from each (ptr, size, device_id, name) tuple."""
    # Map memory-type strings onto the binding enum; unknown strings fall
    # back to VRAM, matching the agent-side conversion.
    type_lookup = {
        "DRAM": MemoryType.DRAM,
        "VRAM": MemoryType.VRAM,
        "GPU": MemoryType.VRAM,
        "BLK": MemoryType.BLK,
        "OBJ": MemoryType.OBJ,
        "FILE": MemoryType.FILE,
    }
    mem_type = type_lookup.get(reg_descs.type.upper(), MemoryType.VRAM)
    region_tuples = [(ptr, size, device_id) for (ptr, size, device_id, _) in reg_descs.descs]
    return MemoryDescs(mem_type, region_tuples)
@dataclass
class MemoryManager:
    """Owns test buffers so registered memory stays alive for the test run."""

    # Strong references to every allocated tensor.
    allocated_memory: list[torch.Tensor] = field(default_factory=list)

    def allocate_memory(
        self, size: int, name: str, memory_type=MemoryType.VRAM, device_id: int = 0
    ) -> RegMemoryDescs:
        """Allocate a zeroed byte buffer, track it, and describe it for registration.

        NOTE(review): callers pass "VRAM"/"DRAM" strings for ``memory_type``;
        assumes ``MemoryType.VRAM`` compares equal to "VRAM" — confirm.
        """
        target = f"cuda:{device_id}" if memory_type == MemoryType.VRAM else "cpu"

        buffer = torch.zeros(size, dtype=torch.uint8, device=torch.device(target))
        self.allocated_memory.append(buffer)

        return RegMemoryDescs(
            type=memory_type, descs=[(buffer.data_ptr(), buffer.numel(), device_id, name)]
        )

    def clear_memory(self):
        """Drop all tracked buffers (freed once deregistered and unreferenced)."""
        self.allocated_memory.clear()
@pytest.fixture
def memory_manager():
    # Fresh manager per test so buffers are isolated between cases.
    return MemoryManager()


@pytest.fixture(params=[256, 512])
def memory_size(request):
    # Parameterized buffer size in bytes.
    return request.param


@pytest.fixture(params=["DRAM", "VRAM"])
def memory_type(request):
    # Memory kind under test: host (DRAM) or device (VRAM).
    return request.param
@pytest.fixture
def alloc(memory_manager, memory_size, memory_type):
    """Allocate memory for source and destination, based on the memory_size and memory_type parameters."""
    assert memory_size > 0, "Memory size must be a positive integer."
    # Device-memory cases need a GPU; skip rather than fail on CPU-only runners.
    if memory_type == "VRAM" and not torch.cuda.is_available():
        pytest.skip("CUDA not available for VRAM transfer tests")
    src_descs = memory_manager.allocate_memory(
        size=memory_size, name="src_mem", memory_type=memory_type
    )
    dst_descs = memory_manager.allocate_memory(
        size=memory_size, name="dst_mem", memory_type=memory_type
    )
    return src_descs, dst_descs
@pytest.fixture
def transfer_agent_src():
    # Sender-side agent under test.
    return NixlTransferAgent(name="src_agent")


@pytest.fixture
def transfer_agent_dst():
    # Receiver-side agent under test.
    return NixlTransferAgent(name="dst_agent")
def test_transfer_between_agents(
    transfer_agent_src,
    transfer_agent_dst,
    memory_manager,
    alloc,
    memory_size,
    memory_type,
):
    """End-to-end test of data transfer between two agents with parameterized memory sizes and types."""
    # Debug log the parameters being tested
    logger.info(f"Testing with memory_size={memory_size}, memory_type={memory_type}")

    # Unpack source and destination memory descriptions
    memory_descs_src, memory_descs_dst = alloc

    # Fill source memory with sequential data for validation
    src_data = memory_manager.allocated_memory[0]
    assert memory_size > 0, "Memory size must be positive."
    tensor = torch.arange(memory_size, dtype=torch.uint8) % 10
    src_data.copy_(tensor)

    # Register memory with source and destination agents
    transfer_agent_src.register_memory(memory_descs_src)
    transfer_agent_dst.register_memory(memory_descs_dst)

    # Exchange agent metadata in both directions so each side knows the peer.
    src_agent_desc = transfer_agent_src.get_local_agent_desc()
    transfer_agent_dst.load_remote_agent("src_agent", src_agent_desc)

    dst_agent_desc = transfer_agent_dst.get_local_agent_desc()
    transfer_agent_src.load_remote_agent("dst_agent", dst_agent_desc)

    # Create and submit the transfer request
    transfer_request = TransferRequest(
        op=TransferOp.WRITE,
        src_descs=_convert_to_memory_descs(memory_descs_src),
        dst_descs=_convert_to_memory_descs(memory_descs_dst),
        remote_name="dst_agent",
        sync_message=None,
    )
    transfer_status = transfer_agent_src.submit_transfer_requests(transfer_request)
    assert transfer_status.wait(timeout_ms=5000), "Transfer did not complete within timeout."

    # Validate transfer completion
    assert transfer_status.is_completed(), "Transfer did not complete successfully."

    # Validate that the destination data matches the source data
    dst_data = memory_manager.allocated_memory[1]
    assert torch.equal(dst_data, src_data), "Destination data does not match source data."

    # Clean up by deregistering memory and clearing allocations
    transfer_agent_src.deregister_memory(memory_descs_src)
    transfer_agent_dst.deregister_memory(memory_descs_dst)
    memory_manager.clear_memory()

    transfer_agent_src.invalidate_remote_agent("dst_agent")
    transfer_agent_dst.invalidate_remote_agent("src_agent")
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
32
tests/unittest/disaggregated/test_agent_multi_backends.py
Normal file
32
tests/unittest/disaggregated/test_agent_multi_backends.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_py_nixl", ["0", "1"])
def test_run_with_different_env(use_py_nixl):
    """Run the agent test suite in a subprocess under each NIXL backend selection.

    The environment variable is set only on a copy passed to the child, so the
    parent pytest process's environment is not mutated (the original version
    leaked TRTLLM_USE_PY_NIXL_KVCACHE into all subsequent tests).

    :param use_py_nixl: "0" for the C++ bindings backend, "1" for pure Python.
    """
    env = os.environ.copy()
    env["TRTLLM_USE_PY_NIXL_KVCACHE"] = use_py_nixl
    print(f"Running tests with TRTLLM_USE_PY_NIXL_KVCACHE={use_py_nixl}")

    test_file_path = os.path.join(os.path.dirname(__file__), "test_agent.py")
    print(f"Running tests in: {test_file_path}")

    result = subprocess.run(
        ["pytest", "--capture=no", test_file_path],
        env=env,
        capture_output=True,
        text=True,
    )

    print(result.stdout)

    # Surface the child's stderr only on failure to keep passing logs short.
    if result.returncode != 0:
        print("Test failed. stderr output:")
        print(result.stderr)

    assert result.returncode == 0, f"Tests failed with TRTLLM_USE_PY_NIXL_KVCACHE={use_py_nixl}"
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
127
tests/unittest/disaggregated/test_messenger.py
Normal file
127
tests/unittest/disaggregated/test_messenger.py
Normal file
@ -0,0 +1,127 @@
|
||||
import socket
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message
|
||||
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
|
||||
|
||||
# Table for TestDecodeMessage. Each case: multipart `message` frames (bytes),
# decode `encoding` / `err_mode`, the `expected` decoded tuple, and `raises`
# (exception type, or None when decoding should succeed).
TEST_CASES = [
    {
        "name": "valid_message",
        "message": [b"hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": ("hello", "world"),
        "raises": None,
    },
    {
        # Non-bytes frame must be rejected up front.
        "name": "invalid_input",
        "message": ["hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": ValueError,
    },
    {
        # 0xFF is not valid UTF-8; strict mode must raise.
        "name": "decoding_error",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": UnicodeDecodeError,
    },
    {
        # Same bad byte, but "ignore" drops it, yielding an empty string.
        "name": "decoding_with_ignore",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "ignore",
        "expected": ("",),
        "raises": None,
    },
]
class TestDecodeMessage(unittest.TestCase):
    """Table-driven tests for decode_message() using TEST_CASES."""

    @parameterized.expand([(case["name"], case) for case in TEST_CASES])
    def test_decode_message(self, name, case):
        expected_error = case["raises"]

        if expected_error:
            with self.assertRaises(expected_error):
                decode_message(
                    case["message"], encoding=case["encoding"], err_mode=case["err_mode"]
                )
        else:
            result = decode_message(
                case["message"], encoding=case["encoding"], err_mode=case["err_mode"]
            )
            self.assertEqual(result, case["expected"])
@pytest.fixture
def dynamic_endpoint():
    """Fixture to dynamically generate an available endpoint with a free port."""
    # NOTE(review): the port is released when the socket closes, before the
    # messenger binds it, so another process could grab it in between —
    # an accepted race for local tests.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # Bind to an available port provided by the OS
        port = s.getsockname()[1]
    return f"tcp://{get_local_ip()}:{port}"
@pytest.fixture
def create_messenger_pair(dynamic_endpoint):
    # Factory fixture producing a (messenger1, messenger2) pair on a shared
    # endpoint. Presumably ZMQMessenger binds for server modes (ROUTER/REP)
    # and connects for client modes (DEALER/REQ) — confirm against messenger.
    def _create_messenger_pair(mode1, mode2):
        messenger1 = ZMQMessenger(
            mode1, endpoint=dynamic_endpoint if mode1 in ["ROUTER", "REP"] else None
        )
        messenger2 = ZMQMessenger(
            mode2, endpoint=dynamic_endpoint if mode2 in ["DEALER", "REQ"] else None
        )
        return messenger1, messenger2

    yield _create_messenger_pair
def test_router_dealer(create_messenger_pair):
    """Test ROUTER and DEALER communication."""
    router, dealer = create_messenger_pair("ROUTER", "DEALER")

    received_messages = []

    # Listener callback: accumulate every received frame.
    def on_message(messages):
        received_messages.extend(messages)

    router.start_listener(on_message)

    dealer.send([b"Hello, ROUTER!"])

    # NOTE(review): a fixed sleep races the listener thread; a poll-with-
    # deadline would be more robust, but 0.1s is usually ample locally.
    time.sleep(0.1)

    assert len(received_messages) > 0
    assert b"Hello, ROUTER!" in received_messages

    router.stop()
    dealer.stop()
def test_req_rep(create_messenger_pair):
    """Test REQ and REP communication."""
    rep, req = create_messenger_pair("REP", "REQ")

    # Echo server: reply with exactly the frames that were received.
    def on_message(messages):
        rep.send(messages)

    rep.start_listener(on_message)

    req.send([b"Hello, REP!"])
    # REQ/REP enforces a strict send/receive lock-step, so this blocks until
    # the echo arrives — no sleep needed.
    response = req.receive()
    assert response == [b"Hello, REP!"]

    req.stop()
    rep.stop()
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user