mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 14:07:21 +08:00
[TRTLLM-9527][feat] Add transferAgent binding (step 1) (#10113)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
parent
846e54aa09
commit
536a8f6a9c
2
.gitignore
vendored
2
.gitignore
vendored
@ -40,6 +40,8 @@ tensorrt_llm/libs
|
||||
tensorrt_llm/bindings.*.so
|
||||
tensorrt_llm/bindings.pyi
|
||||
tensorrt_llm/bindings/**/*.pyi
|
||||
tensorrt_llm/tensorrt_llm_transfer_agent_binding.*.so
|
||||
tensorrt_llm/tensorrt_llm_transfer_agent_binding.pyi
|
||||
tensorrt_llm/deep_ep/
|
||||
tensorrt_llm/deep_ep_cpp_tllm.*.so
|
||||
tensorrt_llm/deep_ep_cpp_tllm.pyi
|
||||
|
||||
@ -274,13 +274,20 @@ private:
|
||||
std::optional<SyncMessage> mSyncMessage;
|
||||
};
|
||||
|
||||
// Result code returned by TransferStatus::wait(). Backed by uint8_t.
enum class TransferState : uint8_t
|
||||
{
|
||||
// Transfer has not finished; MooncakeTransferStatus::wait returns this when its timeout elapses.
kIN_PROGRESS,
|
||||
// Transfer finished without error.
kSUCCESS,
|
||||
// Transfer failed; MooncakeTransferStatus::wait returns this when a status query or a request fails.
kFAILURE,
|
||||
};
|
||||
|
||||
// Data structure for checking the status of active transfer operations.
|
||||
// Abstract interface for observing a submitted transfer. Backend classes such as
// MooncakeTransferStatus override the pure-virtual members declared here.
class TransferStatus
|
||||
{
|
||||
public:
|
||||
// Virtual so that derived status objects can be destroyed through a TransferStatus pointer.
virtual ~TransferStatus() = default;
|
||||
// Completion query; [[nodiscard]] makes the compiler warn when the result is dropped.
[[nodiscard]] virtual bool isCompleted() const = 0;
|
||||
// NOTE(review): both this no-argument wait() and the timeout overload below appear in this
// listing. This looks like the pre-change declaration of a rendered diff — confirm which one
// is current, since wait() with a defaulted argument would be ambiguous next to it.
virtual void wait() const = 0;
|
||||
// Waits on the transfer and reports its TransferState. timeout_ms is in milliseconds and
// defaults to -1; MooncakeTransferStatus::wait treats a negative value as "no timeout" and
// returns kIN_PROGRESS once a non-negative timeout has elapsed.
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
|
||||
};
|
||||
|
||||
struct BaseAgentConfig
|
||||
@ -288,6 +295,8 @@ struct BaseAgentConfig
|
||||
std::string mName;
|
||||
bool useProgThread;
|
||||
bool multiThread;
|
||||
bool useListenThread;
|
||||
unsigned int numWorkers;
|
||||
};
|
||||
|
||||
class BaseTransferAgent
|
||||
|
||||
@ -157,6 +157,7 @@ set(UCX_WRAPPER_TARGET tensorrt_llm_ucx_wrapper)
|
||||
|
||||
# When NIXL_ROOT is set, define the target names for the NIXL wrapper library
# and for the transfer-agent Python binding module.
if(NIXL_ROOT)
|
||||
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
|
||||
# NOTE(review): the nixl_utils CMake file sets this variable again when either
# NIXL or Mooncake is enabled — confirm whether this definition is still needed.
set(TRANSFER_AGENT_BINDING_TARGET tensorrt_llm_transfer_agent_binding)
|
||||
endif()
|
||||
|
||||
if(MOONCAKE_ROOT)
|
||||
|
||||
@ -90,5 +90,5 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
|
||||
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
|
||||
|
||||
add_subdirectory(cache_transmission/ucx_utils)
|
||||
add_subdirectory(cache_transmission/nixl_utils)
|
||||
add_subdirectory(cache_transmission/mooncake_utils)
|
||||
add_subdirectory(cache_transmission/nixl_utils)
|
||||
|
||||
@ -141,7 +141,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
|
||||
NotificationInfo notificationInfo{syncInfo};
|
||||
std::stringstream ss;
|
||||
NotificationInfo::serialize(notificationInfo, ss);
|
||||
status->wait();
|
||||
TransferState transferState = status->wait();
|
||||
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "AgentConnection::send failed");
|
||||
// TODO: there is a bug in request_with_notify https://github.com/ai-dynamo/nixl/pull/252
|
||||
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
|
||||
}
|
||||
@ -246,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
|
||||
|
||||
mAgentName = genUniqueAgentName();
|
||||
// Create Agent
|
||||
BaseAgentConfig config{mAgentName, true};
|
||||
BaseAgentConfig config{mAgentName, true, false, true, 1};
|
||||
m_Agent = makeTransferAgent(backendType, &config);
|
||||
TLLM_CHECK(!mCacheTransBufferManagers.empty());
|
||||
std::vector<MemoryDesc> memDescs;
|
||||
|
||||
@ -36,5 +36,10 @@ if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
|
||||
|
||||
target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
|
||||
PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)
|
||||
|
||||
# Export variables to parent scope for transfer_agent_binding
|
||||
set(TRANSFER_ENGINE_INCLUDE_DIR
|
||||
${TRANSFER_ENGINE_INCLUDE_DIR}
|
||||
PARENT_SCOPE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -47,11 +47,77 @@ MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_
|
||||
TLLM_CHECK(mEngine);
|
||||
}
|
||||
|
||||
// Polls every request of this batch until the batch succeeds, a request fails, or the timeout
// elapses.
//
// timeout_ms: maximum time to wait in milliseconds; a negative value waits indefinitely.
// Returns kSUCCESS when all requests completed (the batch ID is freed on the first such call),
// kFAILURE as soon as a poll pass sees a failed request or a failed status query, and
// kIN_PROGRESS when a non-negative timeout has elapsed with requests still pending.
//
// NOTE(review): this listing interleaves what look like pre-change diff lines (the
// `void wait() const` signature, `while (!isCompleted())` and the 1 ms sleep) with the new
// implementation — confirm which lines are live before relying on this text.
void MooncakeTransferStatus::wait() const
|
||||
TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
|
||||
{
|
||||
while (!isCompleted())
|
||||
// Monotonic start point used for the timeout check at the bottom of the loop.
auto startTime = std::chrono::steady_clock::now();
|
||||
|
||||
while (true)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
// The batch was already freed by an earlier successful wait; report success without
// querying the engine again.
if (mBatchFreed)
|
||||
{
|
||||
return TransferState::kSUCCESS;
|
||||
}
|
||||
|
||||
// Per-pass summary flags, recomputed from scratch on every poll.
bool has_failed = false;
|
||||
bool all_completed = true;
|
||||
|
||||
// Query each of the mRequestCount requests in the batch; the loop does not stop early, so
// every failure in the pass is logged.
for (size_t index = 0; index < mRequestCount; ++index)
|
||||
{
|
||||
transfer_status_t status;
|
||||
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
|
||||
// `status` is only read when rc == 0, thanks to short-circuit evaluation.
if (rc || status.status == STATUS_FAILED)
|
||||
{
|
||||
has_failed = true;
|
||||
if (rc)
|
||||
{
|
||||
TLLM_LOG_ERROR(
|
||||
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR(
|
||||
"Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
|
||||
}
|
||||
}
|
||||
else if (status.status != STATUS_COMPLETED)
|
||||
{
|
||||
all_completed = false;
|
||||
}
|
||||
}
|
||||
|
||||
// If any request failed, return failure
// NOTE(review): the batch ID is not freed on this path — confirm that is intended.
if (has_failed)
|
||||
{
|
||||
return TransferState::kFAILURE;
|
||||
}
|
||||
|
||||
// If all requests completed successfully
if (all_completed)
|
||||
{
|
||||
// Release the batch once and record it, so later calls take the mBatchFreed shortcut above.
freeBatchID(mEngine, mBatchId);
|
||||
mBatchFreed = true;
|
||||
TLLM_LOG_DEBUG("Batch ID %lu freed in wait()", mBatchId);
|
||||
syncSegmentCache(mEngine);
|
||||
return TransferState::kSUCCESS;
|
||||
}
|
||||
|
||||
// If timeout_ms < 0, wait indefinitely
if (timeout_ms < 0)
|
||||
{
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if timeout has elapsed
auto elapsed
|
||||
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
|
||||
.count();
|
||||
// With timeout_ms == 0 this is true on the first pass, giving a single non-blocking poll.
if (elapsed >= timeout_ms)
|
||||
{
|
||||
return TransferState::kIN_PROGRESS;
|
||||
}
|
||||
|
||||
// Give other threads a chance to run before polling again.
std::this_thread::yield();
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,7 +166,7 @@ void MooncakeTransferStatus::wait() const
|
||||
return true;
|
||||
}
|
||||
|
||||
// 64-character alphabet: 'A'-'Z', 'a'-'z', '0'-'9', then '+' and '/'.
// NOTE(review): two declarator lines (west-const and east-const spelling) appear here; this
// looks like a rendered diff in which the first one is the pre-change line — confirm.
const std::string MooncakeBase64Helper::STANDARD_CHARS
|
||||
std::string const MooncakeBase64Helper::STANDARD_CHARS
|
||||
= "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"0123456789+/";
|
||||
@ -361,7 +427,7 @@ AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");
|
||||
|
||||
// Using connection info as agent desc
|
||||
const static size_t kBufLen = 64;
|
||||
static size_t const kBufLen = 64;
|
||||
char connectionInfo[kBufLen];
|
||||
|
||||
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
|
||||
@ -375,7 +441,7 @@ ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");
|
||||
|
||||
const static size_t kBufLen = 64;
|
||||
static size_t const kBufLen = 64;
|
||||
char connectionInfo[kBufLen];
|
||||
|
||||
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
|
||||
@ -399,7 +465,7 @@ ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
|
||||
syncMessage = request.getSyncMessage().value();
|
||||
}
|
||||
|
||||
const static size_t kMaxRequestCount = 1024;
|
||||
static size_t const kMaxRequestCount = 1024;
|
||||
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");
|
||||
|
||||
@ -35,7 +35,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool isCompleted() const override;
|
||||
|
||||
void wait() const override;
|
||||
TransferState wait(int64_t timeout_ms = -1) const override;
|
||||
|
||||
private:
|
||||
transfer_engine_t mEngine;
|
||||
|
||||
@ -13,6 +13,9 @@
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
|
||||
# ============================================================================
|
||||
# NIXL Wrapper Library
|
||||
# ============================================================================
|
||||
if(NIXL_ROOT)
|
||||
find_package(NIXL REQUIRED)
|
||||
# Check if all required packages were found
|
||||
@ -30,6 +33,8 @@ if(NIXL_ROOT)
|
||||
|
||||
# Add include directories
|
||||
target_include_directories(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
|
||||
target_include_directories(${NIXL_WRAPPER_TARGET}
|
||||
PRIVATE ${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
# Link against all NIXL libraries
|
||||
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
|
||||
@ -37,4 +42,85 @@ if(NIXL_ROOT)
|
||||
# Link against CUDA
|
||||
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE CUDA::cudart)
|
||||
|
||||
set(NIXL_ENABLED TRUE)
|
||||
else()
|
||||
set(NIXL_ENABLED FALSE)
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# Check if Mooncake wrapper is available (built in mooncake_utils)
|
||||
# ============================================================================
|
||||
# Mooncake counts as enabled only when MOONCAKE_ROOT is set AND the
# tensorrt_llm_mooncake_wrapper target already exists when this file is
# processed, so the directory that creates that target must be added first.
if(MOONCAKE_ROOT AND TARGET tensorrt_llm_mooncake_wrapper)
|
||||
set(MOONCAKE_ENABLED TRUE)
|
||||
else()
|
||||
set(MOONCAKE_ENABLED FALSE)
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# TensorRT-LLM Transfer Agent Binding Python Module: built if either NIXL or
|
||||
# Mooncake is enabled.
|
||||
# ============================================================================
|
||||
# Build the Python extension module for the transfer agents when at least one
# backend (NIXL or Mooncake) is enabled. The binding flavour (pybind11 or
# nanobind) follows BINDING_TYPE; each enabled backend contributes a compile
# definition, include paths and a wrapper library to link.
if(NIXL_ENABLED OR MOONCAKE_ENABLED)
|
||||
set(TRANSFER_AGENT_BINDING_TARGET "tensorrt_llm_transfer_agent_binding")
|
||||
|
||||
# Collect binding source files
# Exactly one source file is chosen: the pybind one when BINDING_TYPE is
# "pybind", the nanobind one for any other value.
set(AGENT_BINDING_SOURCES "")
|
||||
if(BINDING_TYPE STREQUAL "pybind")
|
||||
list(APPEND AGENT_BINDING_SOURCES agentBindingsPybind.cpp)
|
||||
else()
|
||||
list(APPEND AGENT_BINDING_SOURCES agentBindingsNanobind.cpp)
|
||||
endif()
|
||||
|
||||
# Create the module target with the matching helper command.
if(BINDING_TYPE STREQUAL "pybind")
|
||||
# Use pybind11 (already fetched via FetchContent)
|
||||
pybind11_add_module(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
${AGENT_BINDING_SOURCES})
|
||||
message(STATUS "Building tensorrt_llm_transfer_agent_binding with pybind11")
|
||||
else()
|
||||
# Default to nanobind (already fetched via FetchContent)
|
||||
nanobind_add_module(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
${AGENT_BINDING_SOURCES})
|
||||
message(STATUS "Building tensorrt_llm_transfer_agent_binding with nanobind")
|
||||
endif()
|
||||
|
||||
# Passes -Wno-error for this target only (PRIVATE).
target_compile_options(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE -Wno-error)
|
||||
|
||||
# Add common include directories
|
||||
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
# Conditionally add NIXL support
# ENABLE_NIXL is the macro the binding sources test with #ifdef.
if(NIXL_ENABLED)
|
||||
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ENABLE_NIXL)
|
||||
# NOTE(review): NIXL::nixl is a target name, not a directory; the
# target_link_libraries call below is what normally propagates its include
# paths — confirm whether this line is needed.
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE NIXL::nixl)
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${NIXL_WRAPPER_TARGET})
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE NIXL::nixl)
|
||||
message(STATUS "Transfer agent binding: NIXL support enabled")
|
||||
endif()
|
||||
|
||||
# Conditionally add Mooncake support
# ENABLE_MOONCAKE is the macro the binding sources test with #ifdef.
# TRANSFER_ENGINE_INCLUDE_DIR must be visible in this scope; the
# mooncake_utils CMake file exports it with PARENT_SCOPE.
if(MOONCAKE_ENABLED)
|
||||
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ENABLE_MOONCAKE)
|
||||
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE tensorrt_llm_mooncake_wrapper)
|
||||
message(STATUS "Transfer agent binding: Mooncake support enabled")
|
||||
endif()
|
||||
|
||||
# Common dependencies
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE CUDA::cudart)
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${SHARED_TARGET})
|
||||
|
||||
# Set RPATH for the module to find wrapper libraries
# Search order relative to the module: its own directory, then libs/, then
# libs/nixl/. The same list is used for build and install trees.
set_target_properties(
|
||||
${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl"
|
||||
INSTALL_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl")
|
||||
|
||||
endif()
|
||||
|
||||
@ -0,0 +1,239 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
#include "transferAgent.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
#include "../mooncake_utils/transferAgent.h"
|
||||
#endif
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
namespace nb = nanobind;
|
||||
namespace kvc = tensorrt_llm::executor::kv_cache;
|
||||
|
||||
// nanobind entry point for the `tensorrt_llm_transfer_agent_binding` Python module.
// Registers the kv_cache transfer-agent enums, descriptor/request/config classes and the
// abstract status/agent bases, then (when compiled with ENABLE_NIXL / ENABLE_MOONCAKE) the
// backend-specific subclasses, a `make_transfer_agent` factory, and two boolean module
// attributes reporting which backends were compiled in.
NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
{
|
||||
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (nanobind)";
|
||||
|
||||
// MemoryType enum
|
||||
nb::enum_<kvc::MemoryType>(m, "MemoryType")
|
||||
.value("DRAM", kvc::MemoryType::kDRAM)
|
||||
.value("VRAM", kvc::MemoryType::kVRAM)
|
||||
.value("BLK", kvc::MemoryType::kBLK)
|
||||
.value("OBJ", kvc::MemoryType::kOBJ)
|
||||
.value("FILE", kvc::MemoryType::kFILE);
|
||||
|
||||
// TransferOp enum
|
||||
nb::enum_<kvc::TransferOp>(m, "TransferOp")
|
||||
.value("READ", kvc::TransferOp::kREAD)
|
||||
.value("WRITE", kvc::TransferOp::kWRITE);
|
||||
|
||||
// TransferState enum
|
||||
nb::enum_<kvc::TransferState>(m, "TransferState")
|
||||
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
|
||||
.value("SUCCESS", kvc::TransferState::kSUCCESS)
|
||||
.value("FAILURE", kvc::TransferState::kFAILURE);
|
||||
|
||||
// MemoryDesc class
// Constructed from (addr, len, device_id); all three are exposed as read-only properties.
nb::class_<kvc::MemoryDesc>(m, "MemoryDesc")
|
||||
.def(nb::init<uintptr_t, size_t, uint32_t>(), nb::arg("addr"), nb::arg("len"), nb::arg("device_id"))
|
||||
.def_prop_ro("addr", &kvc::MemoryDesc::getAddr)
|
||||
.def_prop_ro("len", &kvc::MemoryDesc::getLen)
|
||||
.def_prop_ro("device_id", &kvc::MemoryDesc::getDeviceId);
|
||||
|
||||
// MemoryDescs class
// A memory type plus a list of MemoryDesc entries; read-only from Python.
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
|
||||
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
|
||||
.def_prop_ro("type", &kvc::MemoryDescs::getType)
|
||||
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
|
||||
|
||||
// AgentDesc class
// Two constructors: one taking Python `bytes`, one taking `str`. The bytes overload copies
// exactly data.size() bytes into a std::string, and the property returns the descriptor as
// `bytes` built from data()/size(), so the payload round-trips by length.
nb::class_<kvc::AgentDesc>(m, "AgentDesc")
|
||||
.def(
|
||||
"__init__",
|
||||
[](kvc::AgentDesc* self, nb::bytes data)
|
||||
{
|
||||
std::string str(data.c_str(), data.size());
|
||||
// Placement-new into the storage nanobind passes as `self`.
new (self) kvc::AgentDesc{std::move(str)};
|
||||
},
|
||||
nb::arg("backend_agent_desc"))
|
||||
.def(nb::init<std::string>(), nb::arg("backend_agent_desc"))
|
||||
.def_prop_ro("backend_agent_desc",
|
||||
[](kvc::AgentDesc const& self)
|
||||
{
|
||||
auto const& desc = self.getBackendAgentDesc();
|
||||
return nb::bytes(desc.data(), desc.size());
|
||||
});
|
||||
|
||||
// TransferRequest class
// `sync_message` is optional and defaults to None (std::nullopt).
nb::class_<kvc::TransferRequest>(m, "TransferRequest")
|
||||
.def(nb::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
|
||||
std::optional<kvc::SyncMessage>>(),
|
||||
nb::arg("op"), nb::arg("src_descs"), nb::arg("dst_descs"), nb::arg("remote_name"),
|
||||
nb::arg("sync_message") = std::nullopt)
|
||||
.def_prop_ro("op", &kvc::TransferRequest::getOp)
|
||||
.def_prop_ro("src_descs", &kvc::TransferRequest::getSrcDescs)
|
||||
.def_prop_ro("dst_descs", &kvc::TransferRequest::getDstDescs)
|
||||
.def_prop_ro("remote_name", &kvc::TransferRequest::getRemoteName)
|
||||
.def_prop_ro("sync_message", &kvc::TransferRequest::getSyncMessage);
|
||||
|
||||
// TransferStatus base class
// NOTE(review): unlike the backend-specific status bindings further down, these two methods
// carry no gil_scoped_release call guard — confirm that is intended, since wait() can block.
nb::class_<kvc::TransferStatus>(m, "TransferStatus")
|
||||
.def("is_completed", &kvc::TransferStatus::isCompleted)
|
||||
.def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1);
|
||||
|
||||
// BaseAgentConfig struct
// Default-constructible, or built from keyword arguments where only `name` is required;
// the remaining fields default to use_prog_thread=True, multi_thread=False,
// use_listen_thread=False, num_workers=1. All fields are read/write attributes.
nb::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
|
||||
.def(nb::init<>())
|
||||
.def(
|
||||
"__init__",
|
||||
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
|
||||
bool use_listen_thread, unsigned int num_workers) {
|
||||
new (self) kvc::BaseAgentConfig{
|
||||
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
|
||||
},
|
||||
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
|
||||
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
|
||||
.def_rw("name", &kvc::BaseAgentConfig::mName)
|
||||
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
|
||||
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
|
||||
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
|
||||
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
|
||||
|
||||
// BaseTransferAgent class (abstract base)
// No constructor is bound; instances come from the derived classes or make_transfer_agent.
// The two loadRemoteAgent overloads get distinct Python names (`load_remote_agent` and
// `load_remote_agent_by_connection`).
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
|
||||
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, nb::arg("descs"))
|
||||
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, nb::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
|
||||
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, nb::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
|
||||
// release() hands the raw pointer out of the returned smart pointer; take_ownership
// below makes the Python object responsible for deleting it.
{ return self.submitTransferRequests(request).release(); },
|
||||
nb::arg("request"), nb::rv_policy::take_ownership)
|
||||
.def(
|
||||
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
|
||||
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
|
||||
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
// NixlTransferStatus class - release GIL for blocking operations
|
||||
nb::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
|
||||
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("wait", &kvc::NixlTransferStatus::wait, nb::arg("timeout_ms") = -1,
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
|
||||
// NixlTransferAgent class
// Constructible from a BaseAgentConfig; re-binds the BaseTransferAgent method set against
// the NIXL implementation. Only submit_transfer_requests releases the GIL here.
nb::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
|
||||
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
|
||||
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, nb::arg("descs"))
|
||||
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, nb::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, nb::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
// MooncakeTransferStatus class - release GIL for blocking operations
|
||||
nb::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
|
||||
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("wait", &kvc::MooncakeTransferStatus::wait, nb::arg("timeout_ms") = -1,
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
|
||||
// MooncakeTransferAgent class
// Same method set as the NIXL agent above, bound against the Mooncake implementation.
nb::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
|
||||
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
|
||||
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, nb::arg("descs"))
|
||||
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, nb::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, nb::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, nb::arg("name"),
|
||||
nb::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, nb::arg("name"),
|
||||
nb::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
// Factory function to create transfer agent by backend name (uses dynamic loading)
// The config is passed by address to makeTransferAgent; the returned owner is released and
// handed to Python with take_ownership.
m.def(
|
||||
"make_transfer_agent",
|
||||
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
|
||||
{ return kvc::makeTransferAgent(backend, &config).release(); },
|
||||
nb::arg("backend"), nb::arg("config"), nb::rv_policy::take_ownership,
|
||||
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
|
||||
|
||||
// Expose which backends are available
// Both attributes are always defined, so Python code can test them without hasattr().
#ifdef ENABLE_NIXL
|
||||
m.attr("NIXL_ENABLED") = true;
|
||||
#else
|
||||
m.attr("NIXL_ENABLED") = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
m.attr("MOONCAKE_ENABLED") = true;
|
||||
#else
|
||||
m.attr("MOONCAKE_ENABLED") = false;
|
||||
#endif
|
||||
}
|
||||
@ -0,0 +1,234 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
#include "transferAgent.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
#include "../mooncake_utils/transferAgent.h"
|
||||
#endif
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace kvc = tensorrt_llm::executor::kv_cache;
|
||||
|
||||
PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
{
|
||||
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (pybind11)";
|
||||
|
||||
// MemoryType enum
|
||||
py::enum_<kvc::MemoryType>(m, "MemoryType")
|
||||
.value("DRAM", kvc::MemoryType::kDRAM)
|
||||
.value("VRAM", kvc::MemoryType::kVRAM)
|
||||
.value("BLK", kvc::MemoryType::kBLK)
|
||||
.value("OBJ", kvc::MemoryType::kOBJ)
|
||||
.value("FILE", kvc::MemoryType::kFILE);
|
||||
|
||||
// TransferOp enum
|
||||
py::enum_<kvc::TransferOp>(m, "TransferOp")
|
||||
.value("READ", kvc::TransferOp::kREAD)
|
||||
.value("WRITE", kvc::TransferOp::kWRITE);
|
||||
|
||||
// TransferState enum
|
||||
py::enum_<kvc::TransferState>(m, "TransferState")
|
||||
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
|
||||
.value("SUCCESS", kvc::TransferState::kSUCCESS)
|
||||
.value("FAILURE", kvc::TransferState::kFAILURE);
|
||||
|
||||
// MemoryDesc class
|
||||
py::class_<kvc::MemoryDesc>(m, "MemoryDesc")
|
||||
.def(py::init<uintptr_t, size_t, uint32_t>(), py::arg("addr"), py::arg("len"), py::arg("device_id"))
|
||||
.def_property_readonly("addr", &kvc::MemoryDesc::getAddr)
|
||||
.def_property_readonly("len", &kvc::MemoryDesc::getLen)
|
||||
.def_property_readonly("device_id", &kvc::MemoryDesc::getDeviceId);
|
||||
|
||||
// MemoryDescs class
|
||||
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
|
||||
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
|
||||
.def_property_readonly("type", &kvc::MemoryDescs::getType)
|
||||
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
|
||||
|
||||
// AgentDesc class
|
||||
py::class_<kvc::AgentDesc>(m, "AgentDesc")
|
||||
.def(py::init(
|
||||
[](py::bytes data)
|
||||
{
|
||||
std::string str(PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()));
|
||||
return kvc::AgentDesc{std::move(str)};
|
||||
}),
|
||||
py::arg("backend_agent_desc"))
|
||||
.def(py::init<std::string>(), py::arg("backend_agent_desc"))
|
||||
.def_property_readonly("backend_agent_desc",
|
||||
[](kvc::AgentDesc const& self)
|
||||
{
|
||||
auto const& desc = self.getBackendAgentDesc();
|
||||
return py::bytes(desc.data(), desc.size());
|
||||
});
|
||||
|
||||
// TransferRequest class
|
||||
py::class_<kvc::TransferRequest>(m, "TransferRequest")
|
||||
.def(py::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
|
||||
std::optional<kvc::SyncMessage>>(),
|
||||
py::arg("op"), py::arg("src_descs"), py::arg("dst_descs"), py::arg("remote_name"),
|
||||
py::arg("sync_message") = std::nullopt)
|
||||
.def_property_readonly("op", &kvc::TransferRequest::getOp)
|
||||
.def_property_readonly("src_descs", &kvc::TransferRequest::getSrcDescs)
|
||||
.def_property_readonly("dst_descs", &kvc::TransferRequest::getDstDescs)
|
||||
.def_property_readonly("remote_name", &kvc::TransferRequest::getRemoteName)
|
||||
.def_property_readonly("sync_message", &kvc::TransferRequest::getSyncMessage);
|
||||
|
||||
// TransferStatus base class
|
||||
py::class_<kvc::TransferStatus>(m, "TransferStatus")
|
||||
.def("is_completed", &kvc::TransferStatus::isCompleted)
|
||||
.def("wait", &kvc::TransferStatus::wait, py::arg("timeout_ms") = -1);
|
||||
|
||||
// BaseAgentConfig struct
|
||||
py::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init(
|
||||
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
|
||||
unsigned int num_workers) {
|
||||
return kvc::BaseAgentConfig{
|
||||
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
|
||||
}),
|
||||
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
|
||||
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
|
||||
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
|
||||
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
|
||||
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
|
||||
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
|
||||
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
|
||||
|
||||
// BaseTransferAgent class (abstract base)
|
||||
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
|
||||
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, py::arg("descs"))
|
||||
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, py::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
|
||||
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, py::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
py::arg("request"), py::return_value_policy::take_ownership)
|
||||
.def(
|
||||
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
|
||||
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
|
||||
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
// NixlTransferStatus class - release GIL for blocking operations
|
||||
py::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
|
||||
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
|
||||
.def("wait", &kvc::NixlTransferStatus::wait, py::arg("timeout_ms") = -1,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
// NixlTransferAgent class
|
||||
py::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
|
||||
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
|
||||
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, py::arg("descs"))
|
||||
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, py::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, py::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
// MooncakeTransferStatus class - release GIL for blocking operations
|
||||
py::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
|
||||
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
|
||||
.def("wait", &kvc::MooncakeTransferStatus::wait, py::arg("timeout_ms") = -1,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
// MooncakeTransferAgent class
|
||||
py::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
|
||||
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
|
||||
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, py::arg("descs"))
|
||||
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, py::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, py::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
|
||||
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, py::arg("name"),
|
||||
py::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, py::arg("name"),
|
||||
py::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
// Factory function to create transfer agent by backend name (uses dynamic loading)
|
||||
m.def(
|
||||
"make_transfer_agent",
|
||||
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
|
||||
{ return kvc::makeTransferAgent(backend, &config).release(); },
|
||||
py::arg("backend"), py::arg("config"), py::return_value_policy::take_ownership,
|
||||
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
|
||||
|
||||
// Expose which backends are available
|
||||
#ifdef ENABLE_NIXL
|
||||
m.attr("NIXL_ENABLED") = true;
|
||||
#else
|
||||
m.attr("NIXL_ENABLED") = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
m.attr("MOONCAKE_ENABLED") = true;
|
||||
#else
|
||||
m.attr("MOONCAKE_ENABLED") = false;
|
||||
#endif
|
||||
}
|
||||
@ -22,6 +22,7 @@
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <chrono>
|
||||
#include <dirent.h>
|
||||
#include <fcntl.h>
|
||||
#include <ifaddrs.h>
|
||||
@ -31,6 +32,7 @@
|
||||
#include <set>
|
||||
#include <sys/file.h>
|
||||
#include <sys/stat.h>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
@ -318,10 +320,40 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)
|
||||
TLLM_CHECK(mHandle);
|
||||
}
|
||||
|
||||
void NixlTransferStatus::wait() const
|
||||
TransferState NixlTransferStatus::wait(int64_t timeout_ms) const
|
||||
{
|
||||
while (!isCompleted())
|
||||
;
|
||||
auto startTime = std::chrono::steady_clock::now();
|
||||
|
||||
while (true)
|
||||
{
|
||||
auto status = mRawAgent->getXferStatus(mHandle);
|
||||
if (status == NIXL_SUCCESS)
|
||||
{
|
||||
return TransferState::kSUCCESS;
|
||||
}
|
||||
else if (status != NIXL_IN_PROG)
|
||||
{
|
||||
return TransferState::kFAILURE;
|
||||
}
|
||||
|
||||
// If timeout_ms < 0, wait indefinitely until status is not NIXL_IN_PROG
|
||||
if (timeout_ms < 0)
|
||||
{
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if timeout has elapsed
|
||||
auto elapsed
|
||||
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
|
||||
.count();
|
||||
if (elapsed >= timeout_ms)
|
||||
{
|
||||
return TransferState::kIN_PROGRESS;
|
||||
}
|
||||
|
||||
std::this_thread::yield();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] bool NixlTransferStatus::isCompleted() const
|
||||
@ -333,6 +365,7 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
|
||||
: mName{config.mName}
|
||||
{
|
||||
nixl_status_t status;
|
||||
if (config.useListenThread)
|
||||
{
|
||||
FileLock lock("/tmp/trtllm_nixl_port.lock");
|
||||
if (!lock.lock())
|
||||
@ -341,10 +374,18 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
|
||||
}
|
||||
auto envPort = common::getEnvNixlPort();
|
||||
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
|
||||
nixlAgentConfig nixlConfig{config.useProgThread, true, port};
|
||||
nixlAgentConfig nixlConfig{
|
||||
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
|
||||
mAddress = getAvailableIP() + ":" + std::to_string(port);
|
||||
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
|
||||
}
|
||||
else
|
||||
{
|
||||
mAddress.clear();
|
||||
nixlAgentConfig nixlConfig{
|
||||
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
|
||||
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
|
||||
}
|
||||
|
||||
std::string nixlBackend = common::getEnvNixlBackend();
|
||||
// List of supported backends - extend this list as new backends are added
|
||||
@ -645,7 +686,8 @@ void NixlLoopbackAgent::executeLoopbackRequest(
|
||||
|
||||
std::unique_ptr<TransferStatus> status = this->submitLoopbackRequests(memoryDescs, fileDescs, isOffload);
|
||||
TLLM_CHECK_WITH_INFO(status != nullptr, "submitLoopbackRequests failed");
|
||||
status->wait();
|
||||
TransferState transferState = status->wait();
|
||||
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "submitLoopbackRequests failed");
|
||||
|
||||
this->deregisterMemory(memoryDescs);
|
||||
this->deregisterFiles(fileDescs);
|
||||
|
||||
@ -45,7 +45,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool isCompleted() const override;
|
||||
|
||||
void wait() const override;
|
||||
[[nodiscard]] TransferState wait(int64_t timeout_ms = -1) const override;
|
||||
|
||||
private:
|
||||
nixlAgent* mRawAgent{};
|
||||
|
||||
@ -66,3 +66,8 @@ if(NOT WIN32)
|
||||
${TRTLLM_NB_MODULE} PROPERTIES LINK_FLAGS
|
||||
"${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
|
||||
endif()
|
||||
|
||||
# Build transfer_agent_binding when building bindings (if NIXL is enabled).
# The variable is quoted so the condition stays well-formed (and false) when
# TRANSFER_AGENT_BINDING_TARGET is not set.
if(TARGET "${TRANSFER_AGENT_BINDING_TARGET}")
  add_dependencies(${TRTLLM_NB_MODULE} ${TRANSFER_AGENT_BINDING_TARGET})
endif()
|
||||
|
||||
@ -69,3 +69,8 @@ if(NOT WIN32)
|
||||
${TRTLLM_PYBIND_MODULE} PROPERTIES LINK_FLAGS
|
||||
"${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
|
||||
endif()
|
||||
|
||||
# Build transfer_agent_binding when building bindings (if NIXL is enabled).
# The variable is quoted so the condition stays well-formed (and false) when
# TRANSFER_AGENT_BINDING_TARGET is not set.
if(TARGET "${TRANSFER_AGENT_BINDING_TARGET}")
  add_dependencies(${TRTLLM_PYBIND_MODULE} ${TRANSFER_AGENT_BINDING_TARGET})
endif()
|
||||
|
||||
@ -92,7 +92,7 @@ TEST_P(TransferAgentTest, Basic)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -117,10 +117,9 @@ TEST_P(TransferAgentTest, Basic)
|
||||
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
status->wait();
|
||||
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
}
|
||||
|
||||
@ -128,7 +127,7 @@ TEST_P(TransferAgentTest, Basic2)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -152,7 +151,7 @@ TEST_P(TransferAgentTest, Basic2)
|
||||
|
||||
TransferRequest readReq{TransferOp::kREAD, regMem0.getDescs(), regMem1.getDescs(), agent1};
|
||||
auto status = xferAgent0->submitTransferRequests(readReq);
|
||||
status->wait();
|
||||
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
|
||||
@ -163,7 +162,7 @@ TEST_P(TransferAgentTest, DeviceMemory)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -177,8 +176,8 @@ TEST_P(TransferAgentTest, DeviceMemory)
|
||||
cudaMalloc(&dev_ptr1, size);
|
||||
std::vector<char> memory0(size, 10);
|
||||
std::vector<char> memory1(size, 1);
|
||||
cudaMemcpy(dev_ptr0, memory0.data(), size, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dev_ptr1, memory1.data(), size, cudaMemcpyHostToDevice);
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(dev_ptr0, memory0.data(), size, cudaMemcpyHostToDevice));
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(dev_ptr1, memory1.data(), size, cudaMemcpyHostToDevice));
|
||||
RegisteredHostMemory regMem0(
|
||||
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr0, size, deviceId}}}, xferAgent0.get());
|
||||
RegisteredHostMemory regMem1(
|
||||
@ -194,12 +193,13 @@ TEST_P(TransferAgentTest, DeviceMemory)
|
||||
} while (!checked);
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
status->wait();
|
||||
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
|
||||
|
||||
cudaMemcpy(memory0.data(), dev_ptr0, size, cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(memory1.data(), dev_ptr1, size, cudaMemcpyDeviceToHost);
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(memory0.data(), dev_ptr0, size, cudaMemcpyDeviceToHost));
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(memory1.data(), dev_ptr1, size, cudaMemcpyDeviceToHost));
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
|
||||
TLLM_CUDA_CHECK(cudaFree(dev_ptr0));
|
||||
TLLM_CUDA_CHECK(cudaFree(dev_ptr1));
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
@ -209,7 +209,8 @@ TEST_P(TransferAgentTest, Connect)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true}, config2{agent2, true};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1},
|
||||
config2{agent2, true, false, true, 1};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent2 = makeTransferAgent(config2);
|
||||
@ -236,7 +237,7 @@ TEST_P(TransferAgentTest, Connect)
|
||||
} while (!checked);
|
||||
TransferRequest writeReq{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
status->wait();
|
||||
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
xferAgent2->loadRemoteAgent(agent1, connectionInfo);
|
||||
@ -247,7 +248,7 @@ TEST_P(TransferAgentTest, Connect)
|
||||
} while (!checked);
|
||||
TransferRequest writeReq2{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
|
||||
auto status2 = xferAgent2->submitTransferRequests(writeReq2);
|
||||
status2->wait();
|
||||
TLLM_CHECK(status2->wait() == TransferState::kSUCCESS);
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
xferAgent2->invalidateRemoteAgent(agent1);
|
||||
@ -260,7 +261,7 @@ TEST_P(TransferAgentTest, SyncMessage)
|
||||
{
|
||||
constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<size_t>::max();
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
BaseAgentConfig config0{agent0, true, false, true, 1}, config1{agent1, true, false, true, 1};
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
@ -294,7 +295,7 @@ TEST_P(TransferAgentTest, SyncMessage)
|
||||
{
|
||||
notif = xferAgent1->getNotifiedSyncMessages();
|
||||
}
|
||||
status->wait();
|
||||
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
|
||||
TLLM_CHECK(status->isCompleted());
|
||||
TLLM_CHECK(notif.size() == 1);
|
||||
TLLM_CHECK(notif[agent0].size() == 1);
|
||||
@ -343,7 +344,7 @@ TEST_P(TransferAgentTest, SyncMessage)
|
||||
{
|
||||
notif4 = xferAgent0->getNotifiedSyncMessages();
|
||||
}
|
||||
status1->wait();
|
||||
TLLM_CHECK(status1->wait() == TransferState::kSUCCESS);
|
||||
TLLM_CHECK(status1->isCompleted());
|
||||
TLLM_CHECK(notif4.size() == 1);
|
||||
TLLM_CHECK(notif4[agent1].size() == 1);
|
||||
|
||||
@ -370,6 +370,7 @@ def check_missing_libs(lib_name: str) -> list[str]:
|
||||
|
||||
def generate_python_stubs_linux(binding_type: str, venv_python: Path,
|
||||
deep_ep: bool, flash_mla: bool,
|
||||
transfer_agent_binding: bool,
|
||||
binding_lib_name: str):
|
||||
is_nanobind = binding_type == "nanobind"
|
||||
if is_nanobind:
|
||||
@ -411,6 +412,16 @@ def generate_python_stubs_linux(binding_type: str, venv_python: Path,
|
||||
build_run(
|
||||
f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code",
|
||||
env=env_stub_gen)
|
||||
if transfer_agent_binding:
|
||||
# Generate stubs for tensorrt_llm_transfer_agent_binding
|
||||
if is_nanobind:
|
||||
build_run(
|
||||
f"\"{venv_python}\" -m nanobind.stubgen -m tensorrt_llm_transfer_agent_binding -O .",
|
||||
env=env_stub_gen)
|
||||
else:
|
||||
build_run(
|
||||
f"\"{venv_python}\" -m pybind11_stubgen -o . tensorrt_llm_transfer_agent_binding --exit-code",
|
||||
env=env_stub_gen)
|
||||
finally:
|
||||
if link_dir:
|
||||
rmtree(link_dir)
|
||||
@ -801,17 +812,15 @@ def main(*,
|
||||
build_run(
|
||||
f"find {ucx_dir} -type f -name '*.so*' -exec patchelf --set-rpath \'$ORIGIN:$ORIGIN/ucx:$ORIGIN/../\' {{}} \\;"
|
||||
)
|
||||
if os.path.exists(
|
||||
build_dir /
|
||||
"tensorrt_llm/executor/cache_transmission/nixl_utils/libtensorrt_llm_nixl_wrapper.so"
|
||||
):
|
||||
install_file(
|
||||
build_dir /
|
||||
"tensorrt_llm/executor/cache_transmission/nixl_utils/libtensorrt_llm_nixl_wrapper.so",
|
||||
lib_dir / "libtensorrt_llm_nixl_wrapper.so")
|
||||
# NIXL wrapper and libraries
|
||||
nixl_utils_dir = build_dir / "tensorrt_llm/executor/cache_transmission/nixl_utils"
|
||||
if os.path.exists(nixl_utils_dir / "libtensorrt_llm_nixl_wrapper.so"):
|
||||
install_file(nixl_utils_dir / "libtensorrt_llm_nixl_wrapper.so",
|
||||
lib_dir / "libtensorrt_llm_nixl_wrapper.so")
|
||||
build_run(
|
||||
f'patchelf --set-rpath \'$ORIGIN/nixl/\' {lib_dir / "libtensorrt_llm_nixl_wrapper.so"}'
|
||||
)
|
||||
# Copy NIXL libraries
|
||||
if os.path.exists("/opt/nvidia/nvda_nixl"):
|
||||
nixl_dir = lib_dir / "nixl"
|
||||
if nixl_dir.exists():
|
||||
@ -825,6 +834,15 @@ def main(*,
|
||||
build_run(
|
||||
f"find {nixl_dir} -type f -name '*.so*' -exec patchelf --set-rpath \'$ORIGIN:$ORIGIN/plugins:$ORIGIN/../:$ORIGIN/../ucx/:$ORIGIN/../../ucx/\' {{}} \\;"
|
||||
)
|
||||
# Install tensorrt_llm_transfer_agent_binding Python module (standalone agent bindings)
|
||||
# This is built when either NIXL or Mooncake is enabled
|
||||
# Install to tensorrt_llm/ (same level as bindings.so)
|
||||
agent_binding_so = list(
|
||||
nixl_utils_dir.glob("tensorrt_llm_transfer_agent_binding*.so"))
|
||||
if agent_binding_so:
|
||||
trtllm_dir = project_dir / "tensorrt_llm"
|
||||
install_file(agent_binding_so[0],
|
||||
trtllm_dir / agent_binding_so[0].name)
|
||||
if os.path.exists(
|
||||
build_dir /
|
||||
"tensorrt_llm/executor/cache_transmission/mooncake_utils/libtensorrt_llm_mooncake_wrapper.so"
|
||||
@ -946,6 +964,7 @@ def main(*,
|
||||
binding_type, venv_python,
|
||||
bool(deep_ep_cuda_architectures),
|
||||
bool(flash_mla_cuda_architectures),
|
||||
nixl_root is not None or mooncake_root is not None,
|
||||
binding_lib_file_name)
|
||||
|
||||
if not skip_building_wheel:
|
||||
|
||||
2
setup.py
2
setup.py
@ -114,6 +114,8 @@ else:
|
||||
'libs/libnvinfer_plugin_tensorrt_llm.so',
|
||||
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
|
||||
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
|
||||
'tensorrt_llm_transfer_agent_binding*.so',
|
||||
'tensorrt_llm_transfer_agent_binding.pyi',
|
||||
'libs/libtensorrt_llm_mooncake_wrapper.so', 'libs/ucx/**/*',
|
||||
'libs/libpg_utils.so', 'libs/libdecoder_attention_1.so',
|
||||
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
|
||||
|
||||
585
tests/unittest/bindings/test_transfer_agent_bindings.py
Normal file
585
tests/unittest/bindings/test_transfer_agent_bindings.py
Normal file
@ -0,0 +1,585 @@
|
||||
import pytest
|
||||
|
||||
# Try to import the transfer agent binding module
|
||||
try:
|
||||
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as tab
|
||||
|
||||
HAS_TRANSFER_AGENT = True
|
||||
# Check which backends are available
|
||||
HAS_NIXL = getattr(tab, "NIXL_ENABLED", False)
|
||||
HAS_MOONCAKE = getattr(tab, "MOONCAKE_ENABLED", False)
|
||||
except ImportError:
|
||||
HAS_TRANSFER_AGENT = False
|
||||
HAS_NIXL = False
|
||||
HAS_MOONCAKE = False
|
||||
|
||||
# Try to import torch for functional tests
|
||||
try:
|
||||
import torch
|
||||
|
||||
HAS_TORCH = True
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
HAS_CUDA = False
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not HAS_TRANSFER_AGENT,
|
||||
reason="Transfer agent bindings not available (tensorrt_llm_transfer_agent_binding)",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Common Tests (independent of backend)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_memory_type_enum():
    """MemoryType exposes DRAM, VRAM, BLK, OBJ and FILE; adjacent members differ."""
    memory_type = tab.MemoryType
    for member_name in ("DRAM", "VRAM", "BLK", "OBJ", "FILE"):
        assert getattr(memory_type, member_name) is not None

    # Distinctness is spot-checked on two neighbouring pairs only.
    for lhs, rhs in (("DRAM", "VRAM"), ("VRAM", "BLK")):
        assert getattr(memory_type, lhs) != getattr(memory_type, rhs)
|
||||
|
||||
|
||||
def test_transfer_op_enum():
    """TransferOp exposes READ and WRITE, and the two compare unequal."""
    read_op = tab.TransferOp.READ
    write_op = tab.TransferOp.WRITE

    assert read_op is not None
    assert write_op is not None
    assert read_op != write_op
|
||||
|
||||
|
||||
def test_transfer_state_enum():
    """TransferState exposes IN_PROGRESS, SUCCESS and FAILURE, all pairwise unequal."""
    state = tab.TransferState
    in_progress, success, failure = state.IN_PROGRESS, state.SUCCESS, state.FAILURE

    for member in (in_progress, success, failure):
        assert member is not None

    # Every pair of the three members must compare unequal.
    assert in_progress != success
    assert success != failure
    assert in_progress != failure
|
||||
|
||||
|
||||
def test_memory_desc():
    """A MemoryDesc reports back the addr, len and device_id it was built with."""
    expected = {"addr": 0x1000, "len": 4096, "device_id": 0}

    desc = tab.MemoryDesc(expected["addr"], expected["len"], expected["device_id"])

    for field, value in expected.items():
        assert getattr(desc, field) == value
|
||||
|
||||
|
||||
def test_memory_desc_different_values():
    """MemoryDesc round-trips several (addr, len, device_id) combinations."""
    samples = (
        (0x0, 1, 0),
        (0xFFFFFFFF, 65536, 1),
        (0x12345678, 1024, 7),
    )

    for sample in samples:
        desc = tab.MemoryDesc(*sample)
        assert desc.addr == sample[0]
        assert desc.len == sample[1]
        assert desc.device_id == sample[2]
|
||||
|
||||
|
||||
def test_memory_descs():
    """MemoryDescs keeps its memory type and its descriptors in insertion order."""
    first_addr, second_addr = 0x1000, 0x2000
    entries = [
        tab.MemoryDesc(first_addr, 4096, 0),
        tab.MemoryDesc(second_addr, 8192, 0),
    ]

    descs = tab.MemoryDescs(tab.MemoryType.VRAM, entries)

    assert descs.type == tab.MemoryType.VRAM
    assert len(descs.descs) == 2
    assert descs.descs[0].addr == first_addr
    assert descs.descs[1].addr == second_addr
|
||||
|
||||
|
||||
def test_memory_descs_empty():
    """MemoryDescs accepts an empty descriptor list and still records the type."""
    dram = tab.MemoryType.DRAM
    empty = tab.MemoryDescs(dram, [])

    assert empty.type == dram
    assert len(empty.descs) == 0
|
||||
|
||||
|
||||
def test_agent_desc_from_string():
    """An AgentDesc built from a str exposes the encoded bytes as backend_agent_desc."""
    payload = "test_agent_descriptor"
    agent_desc = tab.AgentDesc(payload)

    assert agent_desc.backend_agent_desc == payload.encode()
|
||||
|
||||
|
||||
def test_agent_desc_from_bytes():
    """An AgentDesc built from bytes (including NUL bytes) returns them unchanged."""
    payload = b"test_binary_data\x00\x01\x02"
    agent_desc = tab.AgentDesc(payload)

    assert agent_desc.backend_agent_desc == payload
|
||||
|
||||
|
||||
def test_base_agent_config_default():
    """BaseAgentConfig can be constructed without arguments."""
    default_config = tab.BaseAgentConfig()

    # Only construction is checked here; no individual default value is asserted.
    assert default_config is not None
|
||||
|
||||
|
||||
def test_base_agent_config_custom():
    """Keyword arguments passed to BaseAgentConfig are readable back as attributes."""
    settings = {
        "name": "test_agent",
        "use_prog_thread": True,
        "multi_thread": False,
        "use_listen_thread": True,
        "num_workers": 4,
    }

    config = tab.BaseAgentConfig(**settings)

    for field, expected in settings.items():
        assert getattr(config, field) == expected
|
||||
|
||||
|
||||
def test_base_agent_config_readwrite():
    """Each BaseAgentConfig attribute can be assigned and then read back."""
    config = tab.BaseAgentConfig()

    config.name = "modified_name"
    assert config.name == "modified_name"

    # Boolean fields are checked by identity, so the getter must return a real bool.
    for attr, flag in (
        ("use_prog_thread", False),
        ("multi_thread", True),
        ("use_listen_thread", True),
    ):
        setattr(config, attr, flag)
        assert getattr(config, attr) is flag

    config.num_workers = 8
    assert config.num_workers == 8
|
||||
|
||||
|
||||
def test_transfer_request():
    """A WRITE TransferRequest exposes its op, remote name and both descriptor lists."""
    vram = tab.MemoryType.VRAM
    src_descs = tab.MemoryDescs(vram, [tab.MemoryDesc(0x1000, 4096, 0)])
    dst_descs = tab.MemoryDescs(vram, [tab.MemoryDesc(0x2000, 4096, 1)])
    peer_name = "remote_agent"

    request = tab.TransferRequest(tab.TransferOp.WRITE, src_descs, dst_descs, peer_name)

    assert request.op == tab.TransferOp.WRITE
    assert request.remote_name == peer_name
    assert request.src_descs.type == tab.MemoryType.VRAM
    assert request.dst_descs.type == tab.MemoryType.VRAM
|
||||
|
||||
|
||||
def test_transfer_request_read_op():
    """A READ TransferRequest exposes the READ op and the remote name it was given."""
    dram = tab.MemoryType.DRAM
    src_descs = tab.MemoryDescs(dram, [tab.MemoryDesc(0x3000, 2048, 0)])
    dst_descs = tab.MemoryDescs(dram, [tab.MemoryDesc(0x4000, 2048, 0)])

    request = tab.TransferRequest(tab.TransferOp.READ, src_descs, dst_descs, "another_remote")

    assert request.op == tab.TransferOp.READ
    assert request.remote_name == "another_remote"
|
||||
|
||||
|
||||
def test_backend_availability_flags():
    """The module always exposes NIXL_ENABLED and MOONCAKE_ENABLED as booleans."""
    flag_names = ("NIXL_ENABLED", "MOONCAKE_ENABLED")

    # Both flags must be present regardless of which backends were built.
    for flag_name in flag_names:
        assert hasattr(tab, flag_name)

    for flag_name in flag_names:
        assert isinstance(getattr(tab, flag_name), bool)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NIXL-specific Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlTransferAgent:
    """Structural checks on the NIXL bindings; no agent is instantiated here."""

    def test_nixl_transfer_agent_class_exists(self):
        """The module exposes NixlTransferAgent."""
        assert hasattr(tab, "NixlTransferAgent")

    def test_nixl_transfer_status_class_exists(self):
        """The module exposes NixlTransferStatus."""
        assert hasattr(tab, "NixlTransferStatus")

    def test_nixl_transfer_agent_is_base_subclass(self):
        """NixlTransferAgent derives from BaseTransferAgent."""
        agent_cls = tab.NixlTransferAgent
        assert issubclass(agent_cls, tab.BaseTransferAgent)

    def test_nixl_transfer_status_is_base_subclass(self):
        """NixlTransferStatus derives from TransferStatus."""
        status_cls = tab.NixlTransferStatus
        assert issubclass(status_cls, tab.TransferStatus)

    def test_nixl_transfer_agent_has_required_methods(self):
        """NixlTransferAgent exposes every method name listed below."""
        agent_cls = tab.NixlTransferAgent
        for method in (
            "register_memory",
            "deregister_memory",
            "load_remote_agent",
            "load_remote_agent_by_connection",
            "get_local_agent_desc",
            "get_local_connection_info",
            "invalidate_remote_agent",
            "submit_transfer_requests",
            "notify_sync_message",
            "get_notified_sync_messages",
            "check_remote_descs",
        ):
            assert hasattr(agent_cls, method), f"Missing method: {method}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mooncake-specific Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeTransferAgent:
    """Structural checks on the Mooncake bindings; no agent is instantiated here."""

    def test_mooncake_transfer_agent_class_exists(self):
        """The module exposes MooncakeTransferAgent."""
        assert hasattr(tab, "MooncakeTransferAgent")

    def test_mooncake_transfer_status_class_exists(self):
        """The module exposes MooncakeTransferStatus."""
        assert hasattr(tab, "MooncakeTransferStatus")

    def test_mooncake_transfer_agent_is_base_subclass(self):
        """MooncakeTransferAgent derives from BaseTransferAgent."""
        agent_cls = tab.MooncakeTransferAgent
        assert issubclass(agent_cls, tab.BaseTransferAgent)

    def test_mooncake_transfer_status_is_base_subclass(self):
        """MooncakeTransferStatus derives from TransferStatus."""
        status_cls = tab.MooncakeTransferStatus
        assert issubclass(status_cls, tab.TransferStatus)

    def test_mooncake_transfer_agent_has_required_methods(self):
        """MooncakeTransferAgent exposes every method name listed below."""
        agent_cls = tab.MooncakeTransferAgent
        for method in (
            "register_memory",
            "deregister_memory",
            "load_remote_agent",
            "load_remote_agent_by_connection",
            "get_local_agent_desc",
            "get_local_connection_info",
            "invalidate_remote_agent",
            "submit_transfer_requests",
            "notify_sync_message",
            "get_notified_sync_messages",
            "check_remote_descs",
        ):
            assert hasattr(agent_cls, method), f"Missing method: {method}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Functional Tests - Data Transfer Validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _create_memory_descs_from_tensor(tensor, memory_type):
    """Wrap a torch tensor's buffer in a single-entry MemoryDescs.

    The descriptor uses the tensor's data pointer, a byte length of
    numel * element_size, and the CUDA device index (0 for non-CUDA tensors).
    NOTE(review): the byte length presumably assumes densely packed storage;
    confirm before passing non-contiguous views.
    """
    if tensor.is_cuda:
        device_id = tensor.device.index
    else:
        device_id = 0

    byte_count = tensor.numel() * tensor.element_size()
    single_desc = tab.MemoryDesc(tensor.data_ptr(), byte_count, device_id)
    return tab.MemoryDescs(memory_type, [single_desc])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (HAS_TORCH and HAS_CUDA),
    reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_NIXL, reason="NIXL backend not available")
class TestNixlFunctionalTransfer:
    """End-to-end WRITE transfers between two NIXL agents created in one process."""

    def test_nixl_write_transfer_gpu_tensor(self):
        """WRITE one 1024-element float32 GPU tensor from agent_a to agent_b."""
        gpu = torch.device("cuda:0")
        element_count = 1024

        # The payload differs from the zeroed target, so a transfer that
        # silently does nothing cannot satisfy the final equality check.
        payload = torch.arange(element_count, dtype=torch.float32, device=gpu)
        target = torch.zeros(element_count, dtype=torch.float32, device=gpu)
        tensors_already_equal = torch.equal(payload, target)
        assert not tensors_already_equal

        cfg_a = tab.BaseAgentConfig(
            name="agent_a", use_prog_thread=True, use_listen_thread=False
        )
        cfg_b = tab.BaseAgentConfig(
            name="agent_b", use_prog_thread=True, use_listen_thread=False
        )
        sender = tab.NixlTransferAgent(cfg_a)
        receiver = tab.NixlTransferAgent(cfg_b)

        # Each agent registers only the buffer it owns.
        local_descs = _create_memory_descs_from_tensor(payload, tab.MemoryType.VRAM)
        remote_descs = _create_memory_descs_from_tensor(target, tab.MemoryType.VRAM)
        sender.register_memory(local_descs)
        receiver.register_memory(remote_descs)

        # Swap agent descriptors so each side knows its peer by name.
        sender_blob = sender.get_local_agent_desc()
        receiver_blob = receiver.get_local_agent_desc()
        sender.load_remote_agent("agent_b", receiver_blob)
        receiver.load_remote_agent("agent_a", sender_blob)

        # agent_a pushes its local buffer into agent_b's registered buffer.
        write_request = tab.TransferRequest(
            tab.TransferOp.WRITE, local_descs, remote_descs, "agent_b"
        )
        pending = sender.submit_transfer_requests(write_request)
        result = pending.wait(timeout_ms=5000)
        assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"

        # Synchronize CUDA before comparing the two device tensors.
        torch.cuda.synchronize()
        assert torch.equal(payload, target), "Data mismatch after transfer"

        sender.deregister_memory(local_descs)
        receiver.deregister_memory(remote_descs)

    def test_nixl_write_transfer_multiple_chunks(self):
        """WRITE four separate 256-element chunks in a single transfer request."""
        gpu = torch.device("cuda:0")

        # Chunk k holds the values [k * 256, (k + 1) * 256).
        src_tensors = [
            torch.arange(start, start + 256, dtype=torch.float32, device=gpu)
            for start in (0, 256, 512, 768)
        ]
        dst_tensors = [
            torch.zeros(256, dtype=torch.float32, device=gpu) for _ in src_tensors
        ]

        cfg_a = tab.BaseAgentConfig(
            name="agent_a",
            use_prog_thread=True,
            use_listen_thread=False,
        )
        cfg_b = tab.BaseAgentConfig(
            name="agent_b",
            use_prog_thread=True,
            use_listen_thread=False,
        )
        sender = tab.NixlTransferAgent(cfg_a)
        receiver = tab.NixlTransferAgent(cfg_b)

        # One MemoryDesc per chunk; the device id is fixed at 0 for every chunk.
        local_entries = [
            tab.MemoryDesc(t.data_ptr(), t.numel() * t.element_size(), 0)
            for t in src_tensors
        ]
        remote_entries = [
            tab.MemoryDesc(t.data_ptr(), t.numel() * t.element_size(), 0)
            for t in dst_tensors
        ]
        local_descs = tab.MemoryDescs(tab.MemoryType.VRAM, local_entries)
        remote_descs = tab.MemoryDescs(tab.MemoryType.VRAM, remote_entries)

        sender.register_memory(local_descs)
        receiver.register_memory(remote_descs)

        # Exchange agent descriptors in both directions.
        receiver_blob = receiver.get_local_agent_desc()
        sender.load_remote_agent("agent_b", receiver_blob)
        sender_blob = sender.get_local_agent_desc()
        receiver.load_remote_agent("agent_a", sender_blob)

        write_request = tab.TransferRequest(
            tab.TransferOp.WRITE,
            local_descs,
            remote_descs,
            "agent_b",
        )
        pending = sender.submit_transfer_requests(write_request)
        outcome = pending.wait(timeout_ms=5000)
        assert outcome == tab.TransferState.SUCCESS

        torch.cuda.synchronize()

        # Each destination chunk must now equal its source chunk.
        for i, chunk_pair in enumerate(zip(src_tensors, dst_tensors)):
            sent, received = chunk_pair
            assert torch.equal(sent, received), f"Data mismatch in chunk {i}"

        sender.deregister_memory(local_descs)
        receiver.deregister_memory(remote_descs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (HAS_TORCH and HAS_CUDA),
    reason="Torch with CUDA support required for functional tests",
)
@pytest.mark.skipif(not HAS_MOONCAKE, reason="Mooncake backend not available")
class TestMooncakeFunctionalTransfer:
    """Functional tests for Mooncake data transfer between two agents.

    Each test creates two ``MooncakeTransferAgent`` instances in the same
    process, registers GPU buffers with them, submits a WRITE request from
    agent A to agent B, and checks that the destination data matches the
    source afterwards.
    """

    def test_mooncake_write_transfer_gpu_tensor(self):
        """Test WRITE transfer of GPU tensor data between two Mooncake agents."""
        device = torch.device("cuda:0")

        # Create source tensor with known data pattern
        src_tensor = torch.arange(1024, dtype=torch.float32, device=device)

        # Create destination tensor (zeros)
        dst_tensor = torch.zeros(1024, dtype=torch.float32, device=device)

        # Verify initial state, so a transfer that does nothing cannot pass
        assert not torch.equal(src_tensor, dst_tensor)

        # Create two agents
        config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
        config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)

        agent_a = tab.MooncakeTransferAgent(config_a)
        agent_b = tab.MooncakeTransferAgent(config_b)

        # Register memory regions
        src_descs = _create_memory_descs_from_tensor(src_tensor, tab.MemoryType.VRAM)
        dst_descs = _create_memory_descs_from_tensor(dst_tensor, tab.MemoryType.VRAM)

        agent_a.register_memory(src_descs)
        agent_b.register_memory(dst_descs)

        # Exchange agent descriptors
        agent_a_desc = agent_a.get_local_agent_desc()
        agent_b_desc = agent_b.get_local_agent_desc()

        agent_a.load_remote_agent("mooncake_agent_b", agent_b_desc)
        agent_b.load_remote_agent("mooncake_agent_a", agent_a_desc)

        # Create transfer request: agent_a writes src_tensor to agent_b's dst_tensor
        request = tab.TransferRequest(
            tab.TransferOp.WRITE,
            src_descs,  # local source
            dst_descs,  # remote destination
            "mooncake_agent_b",  # remote agent name
        )

        # Submit transfer and wait for completion. The wait is bounded with the
        # same 5000 ms budget the other functional tests use, so a stalled
        # transfer fails the assertion below instead of blocking the test run.
        # NOTE(review): assumes a timed-out wait returns a non-SUCCESS state - confirm.
        status = agent_a.submit_transfer_requests(request)
        result = status.wait(timeout_ms=5000)

        assert result == tab.TransferState.SUCCESS, f"Transfer failed with state: {result}"

        # Synchronize CUDA before comparing the two device tensors
        torch.cuda.synchronize()

        # Verify data was transferred correctly
        assert torch.equal(src_tensor, dst_tensor), "Data mismatch after transfer"

        # Cleanup
        agent_a.deregister_memory(src_descs)
        agent_b.deregister_memory(dst_descs)

    def test_mooncake_write_transfer_multiple_chunks(self):
        """Test WRITE transfer with multiple memory chunks."""
        device = torch.device("cuda:0")

        # Create multiple source tensors; chunk i holds [i * 256, (i + 1) * 256)
        src_tensors = [
            torch.arange(i * 256, (i + 1) * 256, dtype=torch.float32, device=device)
            for i in range(4)
        ]

        # Create corresponding destination tensors
        dst_tensors = [torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4)]

        # Create agents
        config_a = tab.BaseAgentConfig(name="mooncake_agent_a", use_prog_thread=True)
        config_b = tab.BaseAgentConfig(name="mooncake_agent_b", use_prog_thread=True)

        agent_a = tab.MooncakeTransferAgent(config_a)
        agent_b = tab.MooncakeTransferAgent(config_b)

        # Create memory descriptors for all chunks (device id fixed at 0)
        src_memory_descs = []
        dst_memory_descs = []
        for src, dst in zip(src_tensors, dst_tensors):
            src_memory_descs.append(
                tab.MemoryDesc(src.data_ptr(), src.numel() * src.element_size(), 0)
            )
            dst_memory_descs.append(
                tab.MemoryDesc(dst.data_ptr(), dst.numel() * dst.element_size(), 0)
            )

        src_descs = tab.MemoryDescs(tab.MemoryType.VRAM, src_memory_descs)
        dst_descs = tab.MemoryDescs(tab.MemoryType.VRAM, dst_memory_descs)

        # Register memory
        agent_a.register_memory(src_descs)
        agent_b.register_memory(dst_descs)

        # Exchange agent info
        agent_a.load_remote_agent("mooncake_agent_b", agent_b.get_local_agent_desc())
        agent_b.load_remote_agent("mooncake_agent_a", agent_a.get_local_agent_desc())

        # Transfer all chunks in a single request
        request = tab.TransferRequest(
            tab.TransferOp.WRITE, src_descs, dst_descs, "mooncake_agent_b"
        )
        status = agent_a.submit_transfer_requests(request)
        result = status.wait(timeout_ms=5000)

        assert result == tab.TransferState.SUCCESS

        torch.cuda.synchronize()

        # Verify all chunks
        for i, (src, dst) in enumerate(zip(src_tensors, dst_tensors)):
            assert torch.equal(src, dst), f"Data mismatch in chunk {i}"

        # Cleanup
        agent_a.deregister_memory(src_descs)
        agent_b.deregister_memory(dst_descs)
|
||||
Loading…
Reference in New Issue
Block a user