[TRTLLM-9527][feat] Add transferAgent binding (step 1) (#10113)

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
Chuang Zhu 2026-01-06 08:40:38 +08:00 committed by GitHub
parent 846e54aa09
commit 536a8f6a9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1345 additions and 43 deletions

2
.gitignore vendored
View File

@ -40,6 +40,8 @@ tensorrt_llm/libs
tensorrt_llm/bindings.*.so
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/**/*.pyi
tensorrt_llm/tensorrt_llm_transfer_agent_binding.*.so
tensorrt_llm/tensorrt_llm_transfer_agent_binding.pyi
tensorrt_llm/deep_ep/
tensorrt_llm/deep_ep_cpp_tllm.*.so
tensorrt_llm/deep_ep_cpp_tllm.pyi

View File

@ -274,13 +274,20 @@ private:
std::optional<SyncMessage> mSyncMessage;
};
enum class TransferState : uint8_t
{
kIN_PROGRESS,
kSUCCESS,
kFAILURE,
};
// Data structure for checking the status of active transfer operations.
class TransferStatus
{
public:
virtual ~TransferStatus() = default;
[[nodiscard]] virtual bool isCompleted() const = 0;
virtual void wait() const = 0;
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
};
struct BaseAgentConfig
@ -288,6 +295,8 @@ struct BaseAgentConfig
std::string mName;
bool useProgThread;
bool multiThread;
bool useListenThread;
unsigned int numWorkers;
};
class BaseTransferAgent

View File

@ -157,6 +157,7 @@ set(UCX_WRAPPER_TARGET tensorrt_llm_ucx_wrapper)
if(NIXL_ROOT)
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
set(TRANSFER_AGENT_BINDING_TARGET tensorrt_llm_transfer_agent_binding)
endif()
if(MOONCAKE_ROOT)

View File

@ -90,5 +90,5 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
add_subdirectory(cache_transmission/ucx_utils)
add_subdirectory(cache_transmission/nixl_utils)
add_subdirectory(cache_transmission/mooncake_utils)
add_subdirectory(cache_transmission/nixl_utils)

View File

@ -141,7 +141,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
NotificationInfo notificationInfo{syncInfo};
std::stringstream ss;
NotificationInfo::serialize(notificationInfo, ss);
status->wait();
TransferState transferState = status->wait();
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "AgentConnection::send failed");
// TODO: there is a bug in request_with_notify https://github.com/ai-dynamo/nixl/pull/252
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
}
@ -246,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
mAgentName = genUniqueAgentName();
// Create Agent
BaseAgentConfig config{mAgentName, true};
BaseAgentConfig config{mAgentName, true, false, true, 1};
m_Agent = makeTransferAgent(backendType, &config);
TLLM_CHECK(!mCacheTransBufferManagers.empty());
std::vector<MemoryDesc> memDescs;

View File

@ -36,5 +36,10 @@ if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)
# Export variables to parent scope for transfer_agent_binding
set(TRANSFER_ENGINE_INCLUDE_DIR
${TRANSFER_ENGINE_INCLUDE_DIR}
PARENT_SCOPE)
endif()
endif()

View File

@ -47,11 +47,77 @@ MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_
TLLM_CHECK(mEngine);
}
void MooncakeTransferStatus::wait() const
TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
{
while (!isCompleted())
auto startTime = std::chrono::steady_clock::now();
while (true)
{
std::this_thread::sleep_for(std::chrono::milliseconds(1));
if (mBatchFreed)
{
return TransferState::kSUCCESS;
}
bool has_failed = false;
bool all_completed = true;
for (size_t index = 0; index < mRequestCount; ++index)
{
transfer_status_t status;
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
if (rc || status.status == STATUS_FAILED)
{
has_failed = true;
if (rc)
{
TLLM_LOG_ERROR(
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
}
else
{
TLLM_LOG_ERROR(
"Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
}
}
else if (status.status != STATUS_COMPLETED)
{
all_completed = false;
}
}
// If any request failed, return failure
if (has_failed)
{
return TransferState::kFAILURE;
}
// If all requests completed successfully
if (all_completed)
{
freeBatchID(mEngine, mBatchId);
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed in wait()", mBatchId);
syncSegmentCache(mEngine);
return TransferState::kSUCCESS;
}
// If timeout_ms < 0, wait indefinitely
if (timeout_ms < 0)
{
std::this_thread::yield();
continue;
}
// Check if timeout has elapsed
auto elapsed
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
.count();
if (elapsed >= timeout_ms)
{
return TransferState::kIN_PROGRESS;
}
std::this_thread::yield();
}
}
@ -100,7 +166,7 @@ void MooncakeTransferStatus::wait() const
return true;
}
const std::string MooncakeBase64Helper::STANDARD_CHARS
std::string const MooncakeBase64Helper::STANDARD_CHARS
= "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
@ -361,7 +427,7 @@ AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");
// Using connection info as agent desc
const static size_t kBufLen = 64;
static size_t const kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
@ -375,7 +441,7 @@ ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");
const static size_t kBufLen = 64;
static size_t const kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
@ -399,7 +465,7 @@ ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
syncMessage = request.getSyncMessage().value();
}
const static size_t kMaxRequestCount = 1024;
static size_t const kMaxRequestCount = 1024;
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");

View File

@ -35,7 +35,7 @@ public:
[[nodiscard]] bool isCompleted() const override;
void wait() const override;
TransferState wait(int64_t timeout_ms = -1) const override;
private:
transfer_engine_t mEngine;

View File

@ -13,6 +13,9 @@
# License for the specific language governing permissions and limitations under
# the License.
# ============================================================================
# NIXL Wrapper Library
# ============================================================================
if(NIXL_ROOT)
find_package(NIXL REQUIRED)
# Check if all required packages were found
@ -30,6 +33,8 @@ if(NIXL_ROOT)
# Add include directories
target_include_directories(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
target_include_directories(${NIXL_WRAPPER_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
# Link against all NIXL libraries
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
@ -37,4 +42,85 @@ if(NIXL_ROOT)
# Link against CUDA
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE CUDA::cudart)
set(NIXL_ENABLED TRUE)
else()
set(NIXL_ENABLED FALSE)
endif()
# ============================================================================
# Check if Mooncake wrapper is available (built in mooncake_utils)
# ============================================================================
if(MOONCAKE_ROOT AND TARGET tensorrt_llm_mooncake_wrapper)
set(MOONCAKE_ENABLED TRUE)
else()
set(MOONCAKE_ENABLED FALSE)
endif()
# ============================================================================
# TensorRT-LLM Transfer Agent Binding Python Module Build if either NIXL or
# Mooncake is enabled
# ============================================================================
if(NIXL_ENABLED OR MOONCAKE_ENABLED)
set(TRANSFER_AGENT_BINDING_TARGET "tensorrt_llm_transfer_agent_binding")
# Collect binding source files
set(AGENT_BINDING_SOURCES "")
if(BINDING_TYPE STREQUAL "pybind")
list(APPEND AGENT_BINDING_SOURCES agentBindingsPybind.cpp)
else()
list(APPEND AGENT_BINDING_SOURCES agentBindingsNanobind.cpp)
endif()
if(BINDING_TYPE STREQUAL "pybind")
# Use pybind11 (already fetched via FetchContent)
pybind11_add_module(${TRANSFER_AGENT_BINDING_TARGET}
${AGENT_BINDING_SOURCES})
message(STATUS "Building tensorrt_llm_transfer_agent_binding with pybind11")
else()
# Default to nanobind (already fetched via FetchContent)
nanobind_add_module(${TRANSFER_AGENT_BINDING_TARGET}
${AGENT_BINDING_SOURCES})
message(STATUS "Building tensorrt_llm_transfer_agent_binding with nanobind")
endif()
target_compile_options(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE -Wno-error)
# Add common include directories
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
# Conditionally add NIXL support
if(NIXL_ENABLED)
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ENABLE_NIXL)
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE NIXL::nixl)
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${NIXL_WRAPPER_TARGET})
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE NIXL::nixl)
message(STATUS "Transfer agent binding: NIXL support enabled")
endif()
# Conditionally add Mooncake support
if(MOONCAKE_ENABLED)
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ENABLE_MOONCAKE)
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE tensorrt_llm_mooncake_wrapper)
message(STATUS "Transfer agent binding: Mooncake support enabled")
endif()
# Common dependencies
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE CUDA::cudart)
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${SHARED_TARGET})
# Set RPATH for the module to find wrapper libraries
set_target_properties(
${TRANSFER_AGENT_BINDING_TARGET}
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl"
INSTALL_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl")
endif()

View File

@ -0,0 +1,239 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/transferAgent.h"
#ifdef ENABLE_NIXL
#include "transferAgent.h"
#endif
#ifdef ENABLE_MOONCAKE
#include "../mooncake_utils/transferAgent.h"
#endif
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
namespace kvc = tensorrt_llm::executor::kv_cache;
NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (nanobind)";
// MemoryType enum
nb::enum_<kvc::MemoryType>(m, "MemoryType")
.value("DRAM", kvc::MemoryType::kDRAM)
.value("VRAM", kvc::MemoryType::kVRAM)
.value("BLK", kvc::MemoryType::kBLK)
.value("OBJ", kvc::MemoryType::kOBJ)
.value("FILE", kvc::MemoryType::kFILE);
// TransferOp enum
nb::enum_<kvc::TransferOp>(m, "TransferOp")
.value("READ", kvc::TransferOp::kREAD)
.value("WRITE", kvc::TransferOp::kWRITE);
// TransferState enum
nb::enum_<kvc::TransferState>(m, "TransferState")
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
.value("SUCCESS", kvc::TransferState::kSUCCESS)
.value("FAILURE", kvc::TransferState::kFAILURE);
// MemoryDesc class
nb::class_<kvc::MemoryDesc>(m, "MemoryDesc")
.def(nb::init<uintptr_t, size_t, uint32_t>(), nb::arg("addr"), nb::arg("len"), nb::arg("device_id"))
.def_prop_ro("addr", &kvc::MemoryDesc::getAddr)
.def_prop_ro("len", &kvc::MemoryDesc::getLen)
.def_prop_ro("device_id", &kvc::MemoryDesc::getDeviceId);
// MemoryDescs class
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
.def_prop_ro("type", &kvc::MemoryDescs::getType)
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
// AgentDesc class
nb::class_<kvc::AgentDesc>(m, "AgentDesc")
.def(
"__init__",
[](kvc::AgentDesc* self, nb::bytes data)
{
std::string str(data.c_str(), data.size());
new (self) kvc::AgentDesc{std::move(str)};
},
nb::arg("backend_agent_desc"))
.def(nb::init<std::string>(), nb::arg("backend_agent_desc"))
.def_prop_ro("backend_agent_desc",
[](kvc::AgentDesc const& self)
{
auto const& desc = self.getBackendAgentDesc();
return nb::bytes(desc.data(), desc.size());
});
// TransferRequest class
nb::class_<kvc::TransferRequest>(m, "TransferRequest")
.def(nb::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
std::optional<kvc::SyncMessage>>(),
nb::arg("op"), nb::arg("src_descs"), nb::arg("dst_descs"), nb::arg("remote_name"),
nb::arg("sync_message") = std::nullopt)
.def_prop_ro("op", &kvc::TransferRequest::getOp)
.def_prop_ro("src_descs", &kvc::TransferRequest::getSrcDescs)
.def_prop_ro("dst_descs", &kvc::TransferRequest::getDstDescs)
.def_prop_ro("remote_name", &kvc::TransferRequest::getRemoteName)
.def_prop_ro("sync_message", &kvc::TransferRequest::getSyncMessage);
// TransferStatus base class
nb::class_<kvc::TransferStatus>(m, "TransferStatus")
.def("is_completed", &kvc::TransferStatus::isCompleted)
.def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1);
// BaseAgentConfig struct
nb::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
.def(nb::init<>())
.def(
"__init__",
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
bool use_listen_thread, unsigned int num_workers) {
new (self) kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
},
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
.def_rw("name", &kvc::BaseAgentConfig::mName)
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
// BaseTransferAgent class (abstract base)
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::BaseTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership)
.def(
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
#ifdef ENABLE_NIXL
// NixlTransferStatus class - release GIL for blocking operations
nb::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
.def("wait", &kvc::NixlTransferStatus::wait, nb::arg("timeout_ms") = -1,
nb::call_guard<nb::gil_scoped_release>());
// NixlTransferAgent class
nb::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::NixlTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
.def(
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
#endif
#ifdef ENABLE_MOONCAKE
// MooncakeTransferStatus class - release GIL for blocking operations
nb::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
.def("wait", &kvc::MooncakeTransferStatus::wait, nb::arg("timeout_ms") = -1,
nb::call_guard<nb::gil_scoped_release>());
// MooncakeTransferAgent class
nb::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::MooncakeTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, nb::arg("name"),
nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, nb::arg("name"),
nb::arg("memory_descs"));
#endif
// Factory function to create transfer agent by backend name (uses dynamic loading)
m.def(
"make_transfer_agent",
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
{ return kvc::makeTransferAgent(backend, &config).release(); },
nb::arg("backend"), nb::arg("config"), nb::rv_policy::take_ownership,
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
// Expose which backends are available
#ifdef ENABLE_NIXL
m.attr("NIXL_ENABLED") = true;
#else
m.attr("NIXL_ENABLED") = false;
#endif
#ifdef ENABLE_MOONCAKE
m.attr("MOONCAKE_ENABLED") = true;
#else
m.attr("MOONCAKE_ENABLED") = false;
#endif
}

View File

@ -0,0 +1,234 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/transferAgent.h"
#ifdef ENABLE_NIXL
#include "transferAgent.h"
#endif
#ifdef ENABLE_MOONCAKE
#include "../mooncake_utils/transferAgent.h"
#endif
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace kvc = tensorrt_llm::executor::kv_cache;
PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (pybind11)";
// MemoryType enum
py::enum_<kvc::MemoryType>(m, "MemoryType")
.value("DRAM", kvc::MemoryType::kDRAM)
.value("VRAM", kvc::MemoryType::kVRAM)
.value("BLK", kvc::MemoryType::kBLK)
.value("OBJ", kvc::MemoryType::kOBJ)
.value("FILE", kvc::MemoryType::kFILE);
// TransferOp enum
py::enum_<kvc::TransferOp>(m, "TransferOp")
.value("READ", kvc::TransferOp::kREAD)
.value("WRITE", kvc::TransferOp::kWRITE);
// TransferState enum
py::enum_<kvc::TransferState>(m, "TransferState")
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
.value("SUCCESS", kvc::TransferState::kSUCCESS)
.value("FAILURE", kvc::TransferState::kFAILURE);
// MemoryDesc class
py::class_<kvc::MemoryDesc>(m, "MemoryDesc")
.def(py::init<uintptr_t, size_t, uint32_t>(), py::arg("addr"), py::arg("len"), py::arg("device_id"))
.def_property_readonly("addr", &kvc::MemoryDesc::getAddr)
.def_property_readonly("len", &kvc::MemoryDesc::getLen)
.def_property_readonly("device_id", &kvc::MemoryDesc::getDeviceId);
// MemoryDescs class
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
.def_property_readonly("type", &kvc::MemoryDescs::getType)
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
// AgentDesc class
py::class_<kvc::AgentDesc>(m, "AgentDesc")
.def(py::init(
[](py::bytes data)
{
std::string str(PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()));
return kvc::AgentDesc{std::move(str)};
}),
py::arg("backend_agent_desc"))
.def(py::init<std::string>(), py::arg("backend_agent_desc"))
.def_property_readonly("backend_agent_desc",
[](kvc::AgentDesc const& self)
{
auto const& desc = self.getBackendAgentDesc();
return py::bytes(desc.data(), desc.size());
});
// TransferRequest class
py::class_<kvc::TransferRequest>(m, "TransferRequest")
.def(py::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
std::optional<kvc::SyncMessage>>(),
py::arg("op"), py::arg("src_descs"), py::arg("dst_descs"), py::arg("remote_name"),
py::arg("sync_message") = std::nullopt)
.def_property_readonly("op", &kvc::TransferRequest::getOp)
.def_property_readonly("src_descs", &kvc::TransferRequest::getSrcDescs)
.def_property_readonly("dst_descs", &kvc::TransferRequest::getDstDescs)
.def_property_readonly("remote_name", &kvc::TransferRequest::getRemoteName)
.def_property_readonly("sync_message", &kvc::TransferRequest::getSyncMessage);
// TransferStatus base class
py::class_<kvc::TransferStatus>(m, "TransferStatus")
.def("is_completed", &kvc::TransferStatus::isCompleted)
.def("wait", &kvc::TransferStatus::wait, py::arg("timeout_ms") = -1);
// BaseAgentConfig struct
py::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
.def(py::init<>())
.def(py::init(
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
unsigned int num_workers) {
return kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
}),
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
// BaseTransferAgent class (abstract base)
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::BaseTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership)
.def(
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
#ifdef ENABLE_NIXL
// NixlTransferStatus class - release GIL for blocking operations
py::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
.def("wait", &kvc::NixlTransferStatus::wait, py::arg("timeout_ms") = -1,
py::call_guard<py::gil_scoped_release>());
// NixlTransferAgent class
py::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::NixlTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
.def(
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
#endif
#ifdef ENABLE_MOONCAKE
// MooncakeTransferStatus class - release GIL for blocking operations
py::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
.def("wait", &kvc::MooncakeTransferStatus::wait, py::arg("timeout_ms") = -1,
py::call_guard<py::gil_scoped_release>());
// MooncakeTransferAgent class
py::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::MooncakeTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, py::arg("name"),
py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, py::arg("name"),
py::arg("memory_descs"));
#endif
// Factory function to create transfer agent by backend name (uses dynamic loading)
m.def(
"make_transfer_agent",
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
{ return kvc::makeTransferAgent(backend, &config).release(); },
py::arg("backend"), py::arg("config"), py::return_value_policy::take_ownership,
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
// Expose which backends are available
#ifdef ENABLE_NIXL
m.attr("NIXL_ENABLED") = true;
#else
m.attr("NIXL_ENABLED") = false;
#endif
#ifdef ENABLE_MOONCAKE
m.attr("MOONCAKE_ENABLED") = true;
#else
m.attr("MOONCAKE_ENABLED") = false;
#endif
}

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
@ -31,6 +32,7 @@
#include <set>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>
#include <vector>
@ -318,10 +320,40 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)
TLLM_CHECK(mHandle);
}
void NixlTransferStatus::wait() const
TransferState NixlTransferStatus::wait(int64_t timeout_ms) const
{
while (!isCompleted())
;
auto startTime = std::chrono::steady_clock::now();
while (true)
{
auto status = mRawAgent->getXferStatus(mHandle);
if (status == NIXL_SUCCESS)
{
return TransferState::kSUCCESS;
}
else if (status != NIXL_IN_PROG)
{
return TransferState::kFAILURE;
}
// If timeout_ms < 0, wait indefinitely until status is not NIXL_IN_PROG
if (timeout_ms < 0)
{
std::this_thread::yield();
continue;
}
// Check if timeout has elapsed
auto elapsed
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
.count();
if (elapsed >= timeout_ms)
{
return TransferState::kIN_PROGRESS;
}
std::this_thread::yield();
}
}
[[nodiscard]] bool NixlTransferStatus::isCompleted() const
@ -333,6 +365,7 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
: mName{config.mName}
{
nixl_status_t status;
if (config.useListenThread)
{
FileLock lock("/tmp/trtllm_nixl_port.lock");
if (!lock.lock())
@ -341,10 +374,18 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
}
auto envPort = common::getEnvNixlPort();
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
nixlAgentConfig nixlConfig{config.useProgThread, true, port};
nixlAgentConfig nixlConfig{
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
mAddress = getAvailableIP() + ":" + std::to_string(port);
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
else
{
mAddress.clear();
nixlAgentConfig nixlConfig{
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
std::string nixlBackend = common::getEnvNixlBackend();
// List of supported backends - extend this list as new backends are added
@ -645,7 +686,8 @@ void NixlLoopbackAgent::executeLoopbackRequest(
std::unique_ptr<TransferStatus> status = this->submitLoopbackRequests(memoryDescs, fileDescs, isOffload);
TLLM_CHECK_WITH_INFO(status != nullptr, "submitLoopbackRequests failed");
status->wait();
TransferState transferState = status->wait();
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "submitLoopbackRequests failed");
this->deregisterMemory(memoryDescs);
this->deregisterFiles(fileDescs);

View File

@ -45,7 +45,7 @@ public:
[[nodiscard]] bool isCompleted() const override;
void wait() const override;
[[nodiscard]] TransferState wait(int64_t timeout_ms = -1) const override;
private:
nixlAgent* mRawAgent{};

View File

@ -66,3 +66,8 @@ if(NOT WIN32)
${TRTLLM_NB_MODULE} PROPERTIES LINK_FLAGS
"${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
endif()
# Build transfer_agent_binding when building bindings (if NIXL is enabled)
if(TARGET ${TRANSFER_AGENT_BINDING_TARGET})
add_dependencies(${TRTLLM_NB_MODULE} ${TRANSFER_AGENT_BINDING_TARGET})
endif()

View File

@ -69,3 +69,8 @@ if(NOT WIN32)
${TRTLLM_PYBIND_MODULE} PROPERTIES LINK_FLAGS
"${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
endif()
# Build transfer_agent_binding when building bindings (if NIXL is enabled)
if(TARGET ${TRANSFER_AGENT_BINDING_TARGET})
add_dependencies(${TRTLLM_PYBIND_MODULE} ${TRANSFER_AGENT_BINDING_TARGET})
endif()

View File

@ -92,7 +92,7 @@ TEST_P(TransferAgentTest, Basic)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -117,10 +117,9 @@ TEST_P(TransferAgentTest, Basic)
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
auto status = xferAgent0->submitTransferRequests(writeReq);
status->wait();
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
TLLM_CHECK(memory0 == memory1);
xferAgent0->invalidateRemoteAgent(agent1);
}
@ -128,7 +127,7 @@ TEST_P(TransferAgentTest, Basic2)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -152,7 +151,7 @@ TEST_P(TransferAgentTest, Basic2)
TransferRequest readReq{TransferOp::kREAD, regMem0.getDescs(), regMem1.getDescs(), agent1};
auto status = xferAgent0->submitTransferRequests(readReq);
status->wait();
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
TLLM_CHECK(memory0 == memory1);
@ -163,7 +162,7 @@ TEST_P(TransferAgentTest, DeviceMemory)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -177,8 +176,8 @@ TEST_P(TransferAgentTest, DeviceMemory)
cudaMalloc(&dev_ptr1, size);
std::vector<char> memory0(size, 10);
std::vector<char> memory1(size, 1);
cudaMemcpy(dev_ptr0, memory0.data(), size, cudaMemcpyHostToDevice);
cudaMemcpy(dev_ptr1, memory1.data(), size, cudaMemcpyHostToDevice);
TLLM_CUDA_CHECK(cudaMemcpy(dev_ptr0, memory0.data(), size, cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaMemcpy(dev_ptr1, memory1.data(), size, cudaMemcpyHostToDevice));
RegisteredHostMemory regMem0(
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr0, size, deviceId}}}, xferAgent0.get());
RegisteredHostMemory regMem1(
@ -194,12 +193,13 @@ TEST_P(TransferAgentTest, DeviceMemory)
} while (!checked);
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
auto status = xferAgent0->submitTransferRequests(writeReq);
status->wait();
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
cudaMemcpy(memory0.data(), dev_ptr0, size, cudaMemcpyDeviceToHost);
cudaMemcpy(memory1.data(), dev_ptr1, size, cudaMemcpyDeviceToHost);
TLLM_CUDA_CHECK(cudaMemcpy(memory0.data(), dev_ptr0, size, cudaMemcpyDeviceToHost));
TLLM_CUDA_CHECK(cudaMemcpy(memory1.data(), dev_ptr1, size, cudaMemcpyDeviceToHost));
TLLM_CHECK(memory0 == memory1);
TLLM_CUDA_CHECK(cudaFree(dev_ptr0));
TLLM_CUDA_CHECK(cudaFree(dev_ptr1));
xferAgent0->invalidateRemoteAgent(agent1);
@ -209,7 +209,8 @@ TEST_P(TransferAgentTest, Connect)
{
std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true}, config2{agent2, true};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1},
config2{agent2, true, false, true, 1};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
auto xferAgent2 = makeTransferAgent(config2);
@ -236,7 +237,7 @@ TEST_P(TransferAgentTest, Connect)
} while (!checked);
TransferRequest writeReq{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
auto status = xferAgent0->submitTransferRequests(writeReq);
status->wait();
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
TLLM_CHECK(memory0 == memory1);
xferAgent2->loadRemoteAgent(agent1, connectionInfo);
@ -247,7 +248,7 @@ TEST_P(TransferAgentTest, Connect)
} while (!checked);
TransferRequest writeReq2{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
auto status2 = xferAgent2->submitTransferRequests(writeReq2);
status2->wait();
TLLM_CHECK(status2->wait() == TransferState::kSUCCESS);
TLLM_CHECK(memory0 == memory1);
xferAgent0->invalidateRemoteAgent(agent1);
xferAgent2->invalidateRemoteAgent(agent1);
@ -260,7 +261,7 @@ TEST_P(TransferAgentTest, SyncMessage)
{
constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<size_t>::max();
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
@ -294,7 +295,7 @@ TEST_P(TransferAgentTest, SyncMessage)
{
notif = xferAgent1->getNotifiedSyncMessages();
}
status->wait();
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
TLLM_CHECK(status->isCompleted());
TLLM_CHECK(notif.size() == 1);
TLLM_CHECK(notif[agent0].size() == 1);
@ -343,7 +344,7 @@ TEST_P(TransferAgentTest, SyncMessage)
{
notif4 = xferAgent0->getNotifiedSyncMessages();
}
status1->wait();
TLLM_CHECK(status1->wait() == TransferState::kSUCCESS);
TLLM_CHECK(status1->isCompleted());
TLLM_CHECK(notif4.size() == 1);
TLLM_CHECK(notif4[agent1].size() == 1);

View File

@ -370,6 +370,7 @@ def check_missing_libs(lib_name: str) -> list[str]:
def generate_python_stubs_linux(binding_type: str, venv_python: Path,
deep_ep: bool, flash_mla: bool,
transfer_agent_binding: bool,
binding_lib_name: str):
is_nanobind = binding_type == "nanobind"
if is_nanobind:
@ -411,6 +412,16 @@ def generate_python_stubs_linux(binding_type: str, venv_python: Path,
build_run(
f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code",
env=env_stub_gen)
if transfer_agent_binding:
# Generate stubs for tensorrt_llm_transfer_agent_binding
if is_nanobind:
build_run(
f"\"{venv_python}\" -m nanobind.stubgen -m tensorrt_llm_transfer_agent_binding -O .",
env=env_stub_gen)
else:
build_run(
f"\"{venv_python}\" -m pybind11_stubgen -o . tensorrt_llm_transfer_agent_binding --exit-code",
env=env_stub_gen)
finally:
if link_dir:
rmtree(link_dir)
@ -801,17 +812,15 @@ def main(*,
build_run(
f"find {ucx_dir} -type f -name '*.so*' -exec patchelf --set-rpath \'$ORIGIN:$ORIGIN/ucx:$ORIGIN/../\' {{}} \\;"
)
if os.path.exists(
build_dir /
"tensorrt_llm/executor/cache_transmission/nixl_utils/libtensorrt_llm_nixl_wrapper.so"
):
install_file(
build_dir /
"tensorrt_llm/executor/cache_transmission/nixl_utils/libtensorrt_llm_nixl_wrapper.so",
lib_dir / "libtensorrt_llm_nixl_wrapper.so")
# NIXL wrapper and libraries
nixl_utils_dir = build_dir / "tensorrt_llm/executor/cache_transmission/nixl_utils"
if os.path.exists(nixl_utils_dir / "libtensorrt_llm_nixl_wrapper.so"):
install_file(nixl_utils_dir / "libtensorrt_llm_nixl_wrapper.so",
lib_dir / "libtensorrt_llm_nixl_wrapper.so")
build_run(
f'patchelf --set-rpath \'$ORIGIN/nixl/\' {lib_dir / "libtensorrt_llm_nixl_wrapper.so"}'
)
# Copy NIXL libraries
if os.path.exists("/opt/nvidia/nvda_nixl"):
nixl_dir = lib_dir / "nixl"
if nixl_dir.exists():
@ -825,6 +834,15 @@ def main(*,
build_run(
f"find {nixl_dir} -type f -name '*.so*' -exec patchelf --set-rpath \'$ORIGIN:$ORIGIN/plugins:$ORIGIN/../:$ORIGIN/../ucx/:$ORIGIN/../../ucx/\' {{}} \\;"
)
# Install tensorrt_llm_transfer_agent_binding Python module (standalone agent bindings)
# This is built when either NIXL or Mooncake is enabled
# Install to tensorrt_llm/ (same level as bindings.so)
agent_binding_so = list(
nixl_utils_dir.glob("tensorrt_llm_transfer_agent_binding*.so"))
if agent_binding_so:
trtllm_dir = project_dir / "tensorrt_llm"
install_file(agent_binding_so[0],
trtllm_dir / agent_binding_so[0].name)
if os.path.exists(
build_dir /
"tensorrt_llm/executor/cache_transmission/mooncake_utils/libtensorrt_llm_mooncake_wrapper.so"
@ -946,6 +964,7 @@ def main(*,
binding_type, venv_python,
bool(deep_ep_cuda_architectures),
bool(flash_mla_cuda_architectures),
nixl_root is not None or mooncake_root is not None,
binding_lib_file_name)
if not skip_building_wheel:

View File

@ -114,6 +114,8 @@ else:
'libs/libnvinfer_plugin_tensorrt_llm.so',
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
'tensorrt_llm_transfer_agent_binding*.so',
'tensorrt_llm_transfer_agent_binding.pyi',
'libs/libtensorrt_llm_mooncake_wrapper.so', 'libs/ucx/**/*',
'libs/libpg_utils.so', 'libs/libdecoder_attention_1.so',
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',

View File

@ -0,0 +1,585 @@
import pytest
# Try to import the transfer agent binding module
try:
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as tab
HAS_TRANSFER_AGENT = True
# Check which backends are available
HAS_NIXL = getattr(tab, "NIXL_ENABLED", False)
HAS_MOONCAKE = getattr(tab, "MOONCAKE_ENABLED", False)
except ImportError:
HAS_TRANSFER_AGENT = False
HAS_NIXL = False
HAS_MOONCAKE = False
# Try to import torch for functional tests
try:
import torch
HAS_TORCH = True
HAS_CUDA = torch.cuda.is_available()
except ImportError:
HAS_TORCH = False
HAS_CUDA = False
pytestmark = pytest.mark.skipif(
not HAS_TRANSFER_AGENT,
reason="Transfer agent bindings not available (tensorrt_llm_transfer_agent_binding)",
)
# =============================================================================
# Common Tests (independent of backend)
# =============================================================================
def test_memory_type_enum():
"""Test MemoryType enum values."""
assert tab.MemoryType.DRAM is not None
assert tab.MemoryType.VRAM is not None
assert tab.MemoryType.BLK is not None
assert tab.MemoryType.OBJ is not None
assert tab.MemoryType.FILE is not None
# Verify they are distinct
assert tab.MemoryType.DRAM != tab.MemoryType.VRAM
assert tab.MemoryType.VRAM != tab.MemoryType.BLK
def test_transfer_op_enum():
"""Test TransferOp enum values."""
assert tab.TransferOp.READ is not None
assert tab.TransferOp.WRITE is not None
assert tab.TransferOp.READ != tab.TransferOp.WRITE
def test_transfer_state_enum():
"""Test TransferState enum values."""
assert tab.TransferState.IN_PROGRESS is not None
assert tab.TransferState.SUCCESS is not None
assert tab.TransferState.FAILURE is not None
# Verify they are distinct
assert tab.TransferState.IN_PROGRESS != tab.TransferState.SUCCESS
assert tab.TransferState.SUCCESS != tab.TransferState.FAILURE
assert tab.TransferState.IN_PROGRESS != tab.TransferState.FAILURE
def test_memory_desc():
"""Test MemoryDesc class."""
addr = 0x1000
length = 4096
device_id = 0
desc = tab.MemoryDesc(addr, length, device_id)
assert desc.addr == addr
assert desc.len == length
assert desc.device_id == device_id
def test_memory_desc_different_values():
"""Test MemoryDesc with different values."""
test_cases = [
(0x0, 1, 0),
(0xFFFFFFFF, 65536, 1),
(0x12345678, 1024, 7),
]
for addr, length, device_id in test_cases:
desc = tab.MemoryDesc(addr, length, device_id)
assert desc.addr == addr
assert desc.len == length
assert desc.device_id == device_id
def test_memory_descs():
"""Test MemoryDescs class."""
desc1 = tab.MemoryDesc(0x1000, 4096, 0)
desc2 = tab.MemoryDesc(0x2000, 8192, 0)
descs = tab.MemoryDescs(tab.MemoryType.VRAM, [desc1, desc2])
assert descs.type == tab.MemoryType.VRAM
assert len(descs.descs) == 2
assert descs.descs[0].addr == 0x1000
assert descs.descs[1].addr == 0x2000
def test_memory_descs_empty():
"""Test MemoryDescs with empty list."""
descs = tab.MemoryDescs(tab.MemoryType.DRAM, [])
assert descs.type == tab.MemoryType.DRAM
assert len(descs.descs) == 0
def test_agent_desc_from_string():
"""Test AgentDesc from string."""
test_data = "test_agent_descriptor"
desc = tab.AgentDesc(test_data)
assert desc.backend_agent_desc == test_data.encode()
def test_agent_desc_from_bytes():
"""Test AgentDesc from bytes."""
test_data = b"test_binary_data\x00\x01\x02"
desc = tab.AgentDesc(test_data)
assert desc.backend_agent_desc == test_data
def test_base_agent_config_default():
"""Test BaseAgentConfig with default values."""
config = tab.BaseAgentConfig()
# Default values should be set
assert config is not None
def test_base_agent_config_custom():
"""Test BaseAgentConfig with custom values."""
name = "test_agent"
use_prog_thread = True
multi_thread = False
use_listen_thread = True
num_workers = 4
config = tab.BaseAgentConfig(
name=name,
use_prog_thread=use_prog_thread,
multi_thread=multi_thread,
use_listen_thread=use_listen_thread,
num_workers=num_workers,
)
assert config.name == name
assert config.use_prog_thread == use_prog_thread
assert config.multi_thread == multi_thread
assert config.use_listen_thread == use_listen_thread
assert config.num_workers == num_workers
def test_base_agent_config_readwrite():
"""Test BaseAgentConfig read/write properties."""
config = tab.BaseAgentConfig()
config.name = "modified_name"
assert config.name == "modified_name"
config.use_prog_thread = False
assert config.use_prog_thread is False
config.multi_thread = True
assert config.multi_thread is True
config.use_listen_thread = True
assert config.use_listen_thread is True
config.num_workers = 8
assert config.num_workers == 8
def test_transfer_request():
"""Test TransferRequest class."""
src_desc = tab.MemoryDesc(0x1000, 4096, 0)
dst_desc = tab.MemoryDesc(0x2000, 4096, 1)
src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, [src_desc])
dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, [dst_desc])
remote_name = "remote_agent"
request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, remote_name)
assert request.op == tab.TransferOp.WRITE
assert request.remote_name == remote_name
assert request.src_descs.type == tab.MemoryType.VRAM
assert request.dst_descs.type == tab.MemoryType.VRAM
def test_transfer_request_read_op():
"""Test TransferRequest with READ operation."""
src_desc = tab.MemoryDesc(0x3000, 2048, 0)
dst_desc = tab.MemoryDesc(0x4000, 2048, 0)
src_descs = tab.MemoryDescs(tab.MemoryType.DRAM, [src_desc])
dst_descs = tab.MemoryDescs(tab.MemoryType.DRAM, [dst_desc])
request = tab.TransferRequest(tab.TransferOp.READ, src_descs, dst_descs, "another_remote")
assert request.op == tab.TransferOp.READ
assert request.remote_name == "another_remote"
def test_backend_availability_flags():
"""Test that backend availability flags are exposed."""
# These should always be defined (either True or False)
assert hasattr(tab, "NIXL_ENABLED")
assert hasattr(tab, "MOONCAKE_ENABLED")
assert isinstance(tab.NIXL_ENABLED, bool)
assert isinstance(tab.MOONCAKE_ENABLED, bool)
# =============================================================================
# NIXL-specific Tests
# =============================================================================
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlTransferAgent:
"""Test cases for NixlTransferAgent."""
def test_nixl_transfer_agent_class_exists(self):
"""Test that NixlTransferAgent class exists."""
assert hasattr(tab, "NixlTransferAgent")
def test_nixl_transfer_status_class_exists(self):
"""Test that NixlTransferStatus class exists."""
assert hasattr(tab, "NixlTransferStatus")
def test_nixl_transfer_agent_is_base_subclass(self):
"""Test that NixlTransferAgent is a subclass of BaseTransferAgent."""
assert issubclass(tab.NixlTransferAgent, tab.BaseTransferAgent)
def test_nixl_transfer_status_is_base_subclass(self):
"""Test that NixlTransferStatus is a subclass of TransferStatus."""
assert issubclass(tab.NixlTransferStatus, tab.TransferStatus)
def test_nixl_transfer_agent_has_required_methods(self):
"""Test that NixlTransferAgent has all required methods."""
required_methods = [
"register_memory",
"deregister_memory",
"load_remote_agent",
"load_remote_agent_by_connection",
"get_local_agent_desc",
"get_local_connection_info",
"invalidate_remote_agent",
"submit_transfer_requests",
"notify_sync_message",
"get_notified_sync_messages",
"check_remote_descs",
]
for method in required_methods:
assert hasattr(tab.NixlTransferAgent, method), f"Missing method: {method}"
# =============================================================================
# Mooncake-specific Tests
# =============================================================================
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeTransferAgent:
"""Test cases for MooncakeTransferAgent."""
def test_mooncake_transfer_agent_class_exists(self):
"""Test that MooncakeTransferAgent class exists."""
assert hasattr(tab, "MooncakeTransferAgent")
def test_mooncake_transfer_status_class_exists(self):
"""Test that MooncakeTransferStatus class exists."""
assert hasattr(tab, "MooncakeTransferStatus")
def test_mooncake_transfer_agent_is_base_subclass(self):
"""Test that MooncakeTransferAgent is a subclass of BaseTransferAgent."""
assert issubclass(tab.MooncakeTransferAgent, tab.BaseTransferAgent)
def test_mooncake_transfer_status_is_base_subclass(self):
"""Test that MooncakeTransferStatus is a subclass of TransferStatus."""
assert issubclass(tab.MooncakeTransferStatus, tab.TransferStatus)
def test_mooncake_transfer_agent_has_required_methods(self):
"""Test that MooncakeTransferAgent has all required methods."""
required_methods = [
"register_memory",
"deregister_memory",
"load_remote_agent",
"load_remote_agent_by_connection",
"get_local_agent_desc",
"get_local_connection_info",
"invalidate_remote_agent",
"submit_transfer_requests",
"notify_sync_message",
"get_notified_sync_messages",
"check_remote_descs",
]
for method in required_methods:
assert hasattr(tab.MooncakeTransferAgent, method), f"Missing method: {method}"
# =============================================================================
# Functional Tests - Data Transfer Validation
# =============================================================================
def _create_memory_descs_from_tensor(tensor, memory_type):
"""Helper to create MemoryDescs from a torch tensor."""
addr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
device_id = tensor.device.index if tensor.is_cuda else 0
desc = tab.MemoryDesc(addr, size, device_id)
return tab.MemoryDescs(memory_type, [desc])
@pytest.mark.skipif(
not (HAS_TORCH and HAS_CUDA),
reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlFunctionalTransfer:
"""Functional tests for NIXL data transfer between two agents."""
def test_nixl_write_transfer_gpu_tensor(self):
"""Test WRITE transfer of GPU tensor data between two NIXL agents."""
device = torch.device("cuda:0")
# Create source tensor with known data pattern
src_tensor = torch.arange(1024, dtype=torch.float32, device=device)
# Create destination tensor (zeros)
dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)
# Verify initial state
assert not torch.equal(src_tensor, dst_tensor)
# Create two agents
config_a = tab.BaseAgentConfig(
name="agent_a",
use_prog_thread=True,
use_listen_thread=False,
)
config_b = tab.BaseAgentConfig(
name="agent_b",
use_prog_thread=True,
use_listen_thread=False,
)
agent_a = tab.NixlTransferAgent(config_a)
agent_b = tab.NixlTransferAgent(config_b)
# Register memory regions
src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
# Exchange agent descriptors
agent_a_desc = agent_a.get_local_agent_desc()
agent_b_desc = agent_b.get_local_agent_desc()
agent_a.load_remote_agent("agent_b", agent_b_desc)
agent_b.load_remote_agent("agent_a", agent_a_desc)
# Create transfer request: agent_a writes src_tensor to agent_b's dst_tensor
request = tab.TransferRequest(
tab.TransferOp.WRITE,
src_descs, # local source
dst_descs, # remote destination
"agent_b", # remote agent name
)
# Submit transfer and wait for completion
status = agent_a.submit_transfer_requests(request)
result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"
# Synchronize CUDA to ensure transfer is complete
torch.cuda.synchronize()
# Verify data was transferred correctly
assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)
def test_nixl_write_transfer_multiple_chunks(self):
"""Test WRITE transfer with multiple memory chunks."""
device = torch.device("cuda:0")
# Create multiple source tensors
src_tensors = [
torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
for i in range(4)
]
# Create corresponding destination tensors
dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]
# Create agents
config_a = tab.BaseAgentConfig(
name="agent_a", use_prog_thread=True, use_listen_thread=False
)
config_b = tab.BaseAgentConfig(
name="agent_b", use_prog_thread=True, use_listen_thread=False
)
agent_a = tab.NixlTransferAgent(config_a)
agent_b = tab.NixlTransferAgent(config_b)
# Create memory descriptors for all chunks
src_memory_descs = []
dst_memory_descs = []
for src, dst in zip(src_tensors, dst_tensors):
src_memory_descs.append(
tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
)
dst_memory_descs.append(
tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
)
src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)
# Register memory
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
# Exchange agent info
agent_a.load_remote_agent("agent_b", agent_b.get_local_agent_desc())
agent_b.load_remote_agent("agent_a", agent_a.get_local_agent_desc())
# Transfer
request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, "agent_b")
status = agent_a.submit_transfer_requests(request)
result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS
torch.cuda.synchronize()
# Verify all chunks
for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
assert torch.equal(src, dst), f"Data mismatch in chunk {i}"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)
@pytest.mark.skipif(
not (HAS_TORCH and HAS_CUDA),
reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeFunctionalTransfer:
"""Functional tests for Mooncake data transfer between two agents."""
def test_mooncake_write_transfer_gpu_tensor(self):
"""Test WRITE transfer of GPU tensor data between two Mooncake agents."""
device = torch.device("cuda:0")
# Create source tensor with known data pattern
src_tensor = torch.arange(1024, dtype=torch.float32, device=device)
# Create destination tensor (zeros)
dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)
# Verify initial state
assert not torch.equal(src_tensor, dst_tensor)
# Create two agents
config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)
agent_a = tab.MooncakeTransferAgent(config_a)
agent_b = tab.MooncakeTransferAgent(config_b)
# Register memory regions
src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
agent_a_desc = agent_a.get_local_agent_desc()
agent_b_desc = agent_b.get_local_agent_desc()
agent_a.load_remote_agent("mooncake_agent_b", agent_b_desc)
agent_b.load_remote_agent("mooncake_agent_a", agent_a_desc)
request = tab.TransferRequest(
tab.TransferOp.WRITE,
src_descs, # local source
dst_descs, # remote destination
"mooncake_agent_b", # remote agent name
)
# # Submit transfer and wait for completion
status = agent_a.submit_transfer_requests(request)
result = status.wait()
assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"
# Synchronize CUDA to ensure transfer is complete
torch.cuda.synchronize()
# Verify data was transferred correctly
assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)
def test_mooncake_write_transfer_multiple_chunks(self):
"""Test WRITE transfer with multiple memory chunks."""
device = torch.device("cuda:0")
# Create multiple source tensors
src_tensors = [
torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
for i in range(4)
]
# Create corresponding destination tensors
dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]
# Create agents
config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)
agent_a = tab.MooncakeTransferAgent(config_a)
agent_b = tab.MooncakeTransferAgent(config_b)
# Create memory descriptors for all chunks
src_memory_descs = []
dst_memory_descs = []
for src, dst in zip(src_tensors, dst_tensors):
src_memory_descs.append(
tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
)
dst_memory_descs.append(
tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
)
src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)
# Register memory
agent_a.register_memory(src_descs)
agent_b.register_memory(dst_descs)
# Exchange agent info
agent_a.load_remote_agent("mooncake_agent_b", agent_b.get_local_agent_desc())
agent_b.load_remote_agent("mooncake_agent_a", agent_a.get_local_agent_desc())
# Transfer
request = tab.TransferRequest(
tab.TransferOp.WRITE, src_descs, dst_descs, "mooncake_agent_b"
)
status = agent_a.submit_transfer_requests(request)
result = status.wait(timeout_ms=5000)
assert result == tab.TransferState.SUCCESS
torch.cuda.synchronize()
# Verify all chunks
for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
assert torch.equal(src, dst), f"Data mismatch in chunk {i}"
# Cleanup
agent_a.deregister_memory(src_descs)
agent_b.deregister_memory(dst_descs)