[None][feat] Support Mooncake transfer engine as a cache transceiver backend (#8309)

Signed-off-by: wjueyao <wyao123@terpmail.umd.edu>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
This commit is contained in:
Wangjue Yao 2025-12-19 10:09:51 +08:00 committed by GitHub
parent e0b2a94309
commit 9f283f330b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1353 additions and 147 deletions

View File

@ -1468,7 +1468,8 @@ public:
DEFAULT = 0,
MPI = 1,
UCX = 2,
NIXL = 3
NIXL = 3,
MOONCAKE = 4
};
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,

View File

@ -391,6 +391,14 @@ template <typename... Args>
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
return func(std::forward<Args>(args)...);
}
if (backend == "mooncake")
{
auto& loader = DynLibLoader::getInstance();
using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
return func(std::forward<Args>(args)...);
}
TLLM_THROW("Unknown backend name.");
}

View File

@ -159,6 +159,10 @@ if(NIXL_ROOT)
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
endif()
if(MOONCAKE_ROOT)
set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
endif()
add_subdirectory(executor)
find_package(Threads REQUIRED)
@ -272,6 +276,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
endif()
if(TARGET ${MOONCAKE_WRAPPER_TARGET})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
endif()
if(NOT WIN32)
# Load libraries at $PREFIX/lib from
# $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs

View File

@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
}
else if (common::getEnvUseMooncakeKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
}
else if (common::getEnvUseMPIKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
@ -203,9 +208,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState);
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
{
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());

View File

@ -281,6 +281,12 @@ bool getEnvUseNixlKvCache()
return useNixlKvCache;
}
bool getEnvUseMooncakeKvCache()
{
static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
return useMooncakeKvCache;
}
bool getEnvUseRoundRobinBlockDistForCP()
{
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
@ -343,6 +349,23 @@ std::string getEnvNixlBackend()
return nixlBackend;
}
std::string getEnvMooncakeInterface()
{
static std::once_flag flag;
static std::string mooncakeInterface;
std::call_once(flag,
[&]()
{
char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
if (mooncake_interface)
{
mooncakeInterface = mooncake_interface;
}
});
return mooncakeInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");

View File

@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
bool getEnvUseUCXKvCache();
bool getEnvUseMPIKvCache();
bool getEnvUseNixlKvCache();
bool getEnvUseMooncakeKvCache();
bool getEnvUseRoundRobinBlockDistForCP();
std::string getEnvUCXInterface();
@ -93,6 +96,8 @@ std::string getEnvNixlInterface();
std::string getEnvNixlBackend();
std::string getEnvMooncakeInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();

View File

@ -0,0 +1,226 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <arpa/inet.h>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <unistd.h>
TRTLLM_NAMESPACE_BEGIN
namespace common
{
std::string getLocalIpByNic(std::string const& interface, int rank)
{
struct ifaddrs* ifaddr = nullptr;
if (getifaddrs(&ifaddr) == -1)
{
TLLM_LOG_ERROR(rank,
"getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is "
"set "
"correctly.");
return std::string{};
}
for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
{
if (ifa->ifa_addr == nullptr)
{
continue;
}
if (ifa->ifa_name == interface)
{
if (ifa->ifa_addr->sa_family == AF_INET)
{
char ip[INET_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
else if (ifa->ifa_addr->sa_family == AF_INET6)
{
char ip[INET6_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
}
}
freeifaddrs(ifaddr);
TLLM_LOG_ERROR(
rank, "Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is set correctly.");
return std::string{};
}
std::string getLocalIpByHostname(int rank)
{
char hostname[256]{};
if (gethostname(hostname, sizeof(hostname)) == -1)
{
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
return std::string{};
}
struct addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_CANONNAME;
struct addrinfo* res = nullptr;
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
{
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
return std::string{};
}
for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
{
if (p->ai_family == AF_INET)
{ // IPv4
char ip[INET_ADDRSTRLEN]{};
struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
void* addr = &(ipv4->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
&& std::strcmp(ip, "0.0.0.0") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
else if (p->ai_family == AF_INET6)
{ // IPv6
char ip[INET6_ADDRSTRLEN]{};
struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
void* addr = &(ipv6->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
}
freeaddrinfo(res);
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
return std::string{};
}
std::string getLocalIpByRemoteOrHostName(int rank)
{
// Try IPv4
struct sockaddr_in addr
{
};
addr.sin_family = AF_INET;
addr.sin_port = htons(80);
// using google's public dns server to get the local ip which can be accessed from remote
char const* dns_ip_v4 = "8.8.8.8";
inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);
int sock = socket(AF_INET, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
{
socklen_t addr_len = sizeof(addr);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
{
char ip[INET_ADDRSTRLEN]{};
inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}
// Try IPv6
struct sockaddr_in6 addr6
{
};
addr6.sin6_family = AF_INET6;
addr6.sin6_port = htons(80);
// using google's public dns server
char const* dns_ipv6 = "2001:4860:4860::8888";
inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);
sock = socket(AF_INET6, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
{
socklen_t addr_len = sizeof(addr6);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
{
char ip[INET6_ADDRSTRLEN]{};
inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}
// Try hostname
return getLocalIpByHostname(rank);
}
std::string getLocalIp(std::string interface, int rank)
{
std::string localIP = {};
if (!interface.empty())
{
localIP = getLocalIpByNic(interface, rank);
}
if (localIP.empty())
{
localIP = getLocalIpByRemoteOrHostName(rank);
}
// check whether the localIP is valid
if (localIP.empty())
{
TLLM_THROW("getLocalIp: Can't get local ip");
}
return localIP;
}
} // namespace common
TRTLLM_NAMESPACE_END

View File

@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/config.h"
#include <string>
TRTLLM_NAMESPACE_BEGIN
namespace common
{
std::string getLocalIp(std::string interface, int rank);
} // namespace common
TRTLLM_NAMESPACE_END

View File

@ -91,3 +91,4 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
add_subdirectory(cache_transmission/ucx_utils)
add_subdirectory(cache_transmission/nixl_utils)
add_subdirectory(cache_transmission/mooncake_utils)

View File

@ -236,7 +236,7 @@ bool AgentConnection::recvReadySignal(DataContext const& ctx) const
AgentConnectionManager::AgentConnectionManager(
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
CacheState cacheState)
CacheState cacheState, std::string const& backendType)
: mCacheState(std::move(cacheState))
, mCacheTransBufferManagers(std::move(cacheTransBufferManagers))
, mRegMemDescs(MemoryType::kVRAM, {})
@ -247,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
mAgentName = genUniqueAgentName();
// Create Agent
BaseAgentConfig config{mAgentName, true};
m_Agent = makeTransferAgent("nixl", &config);
m_Agent = makeTransferAgent(backendType, &config);
TLLM_CHECK(!mCacheTransBufferManagers.empty());
std::vector<MemoryDesc> memDescs;
for (auto* cacheTransBufferManager : mCacheTransBufferManagers)

View File

@ -277,7 +277,7 @@ class AgentConnectionManager : public ConnectionManager
public:
AgentConnectionManager(
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
CacheState cacheState);
CacheState cacheState, std::string const& backendType);
~AgentConnectionManager();
AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;

View File

@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
# Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
# MOONCAKE is not supported on Rocky8 for now
set(IS_ROCKY8 FALSE)
if(EXISTS "/etc/redhat-release")
set(IS_ROCKY8 TRUE)
endif()
if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
find_library(TRANSFER_ENGINE_LIB transfer_engine ${MOONCAKE_ROOT}/lib)
find_path(TRANSFER_ENGINE_INCLUDE_DIR transfer_engine_c.h
${MOONCAKE_ROOT}/include)
message(STATUS "Find transfer engine results:")
message(STATUS " TRANSFER_ENGINE_LIB = ${TRANSFER_ENGINE_LIB}")
message(
STATUS " TRANSFER_ENGINE_INCLUDE_DIR = ${TRANSFER_ENGINE_INCLUDE_DIR}")
if(TRANSFER_ENGINE_LIB AND TRANSFER_ENGINE_INCLUDE_DIR)
set(MOONCAKE_WRAPPER_TARGET "tensorrt_llm_mooncake_wrapper")
add_library(${MOONCAKE_WRAPPER_TARGET} SHARED transferAgent.cpp)
target_compile_options(${MOONCAKE_WRAPPER_TARGET} PRIVATE -Wno-error)
target_include_directories(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)
endif()
endif()

View File

@ -0,0 +1,546 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/transferAgent.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <algorithm>
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>
namespace tensorrt_llm::executor::kv_cache
{
MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount)
: mEngine{engine}
, mBatchId{batchId}
, mRequestCount{requestCount}
{
TLLM_CHECK(mEngine);
}
void MooncakeTransferStatus::wait() const
{
while (!isCompleted())
{
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
[[nodiscard]] bool MooncakeTransferStatus::isCompleted() const
{
if (mBatchFreed)
{
return true;
}
bool has_failed = false;
for (size_t index = 0; index < mRequestCount; ++index)
{
transfer_status_t status;
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
if (rc || status.status == STATUS_FAILED)
{
has_failed = true;
if (rc)
{
TLLM_LOG_ERROR(
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
}
else
{
TLLM_LOG_ERROR("Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
}
}
else if (status.status == STATUS_PENDING || status.status == STATUS_WAITING)
{
TLLM_LOG_DEBUG("Transfer is pending for batch %lu, task %zu", mBatchId, index);
return false;
}
}
if (!has_failed)
{
// Each batchId has the batch size, and cannot process more requests
// than the batch size. So, free the batch id here to workaround the issue
// where the same batchId could be used to post multiple transfer.
freeBatchID(mEngine, mBatchId);
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed, future calls will return true directly", mBatchId);
}
// Currently, we cannot distinguish between failed and completed from return value.
TLLM_LOG_DEBUG("Transfer is completed for batch %lu", mBatchId);
return true;
}
const std::string MooncakeBase64Helper::STANDARD_CHARS
= "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
std::string MooncakeBase64Helper::encode(std::vector<uint8_t> const& data)
{
return encodeInternal(data, STANDARD_CHARS);
}
std::string MooncakeBase64Helper::encode(std::string const& data)
{
std::vector<uint8_t> vec(data.begin(), data.end());
return encode(vec);
}
std::vector<uint8_t> MooncakeBase64Helper::decode(std::string const& encoded)
{
return decodeInternal(encoded, STANDARD_CHARS);
}
std::string MooncakeBase64Helper::decodeToString(std::string const& encoded)
{
auto vec = decode(encoded);
return std::string(vec.begin(), vec.end());
}
std::string MooncakeBase64Helper::encodeInternal(std::vector<uint8_t> const& data, std::string const& chars)
{
std::string encoded;
size_t i = 0;
size_t j = 0;
std::array<uint8_t, 3> charArray3{};
std::array<uint8_t, 4> charArray4{};
size_t dataLen = data.size();
uint8_t const* bytes = data.data();
while (dataLen--)
{
charArray3[i++] = *(bytes++);
if (i == 3)
{
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
charArray4[3] = charArray3[2] & 0x3f;
for (i = 0; i < 4; i++)
{
encoded += chars[charArray4[i]];
}
i = 0;
}
}
if (i > 0)
{
for (j = i; j < 3; j++)
{
charArray3[j] = '\0';
}
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
charArray4[3] = charArray3[2] & 0x3f;
for (j = 0; j < i + 1; j++)
{
encoded += chars[charArray4[j]];
}
while (i++ < 3)
{
encoded += '=';
}
}
return encoded;
}
std::vector<uint8_t> MooncakeBase64Helper::decodeInternal(std::string const& encoded, std::string const& chars)
{
size_t encodedLen = encoded.size();
size_t i = 0;
size_t j = 0;
size_t in_ = 0;
std::array<uint8_t, 3> charArray3{};
std::array<uint8_t, 4> charArray4{};
std::vector<uint8_t> decoded;
std::string cleanEncoded;
for (char c : encoded)
{
if (!isWhitespace(c))
{
cleanEncoded += c;
}
}
encodedLen = cleanEncoded.size();
while (encodedLen-- && cleanEncoded[in_] != '=' && isBase64(cleanEncoded[in_], chars))
{
charArray4[i++] = cleanEncoded[in_];
in_++;
if (i == 4)
{
for (i = 0; i < 4; i++)
{
charArray4[i] = chars.find(charArray4[i]);
}
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
for (i = 0; i < 3; i++)
{
decoded.push_back(charArray3[i]);
}
i = 0;
}
}
if (i > 0)
{
for (j = i; j < 4; j++)
{
charArray4[j] = 0;
}
for (j = 0; j < 4; j++)
{
charArray4[j] = chars.find(charArray4[j]);
}
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
for (j = 0; j < i - 1; j++)
{
decoded.push_back(charArray3[j]);
}
}
return decoded;
}
bool MooncakeBase64Helper::isBase64(uint8_t c, std::string const& chars)
{
return (isalnum(c) || (c == chars[62]) || (c == chars[63]));
}
bool MooncakeBase64Helper::isWhitespace(uint8_t c)
{
return (c == ' ' || c == '\n' || c == '\r' || c == '\t');
}
MooncakeTransferAgent::MooncakeTransferAgent(BaseAgentConfig const& config)
{
mLocalAgentName = config.mName;
std::string segmentName = "127.0.0.1";
if (getenv("TLLM_MOONCAKE_IP_ADDR"))
{
segmentName = std::string(getenv("TLLM_MOONCAKE_IP_ADDR"));
}
else
{
auto ip = common::getLocalIp(common::getEnvMooncakeInterface(), mpi::MpiComm::session().getRank());
if (!ip.empty())
segmentName = ip;
}
mEngine = createTransferEngine("P2PHANDSHAKE", segmentName.c_str(), "", 0, true);
}
void MooncakeTransferAgent::registerMemory(RegisterDescs const& descs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::registerMemory");
std::lock_guard<std::mutex> lock(mMutex);
for (auto const& desc : descs.getDescs())
{
auto it = mMemRegInfo.find(desc.getAddr());
if (it != mMemRegInfo.end())
{
it->second->addRef();
continue;
}
int err = registerLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()), desc.getLen(), "*", 1);
TLLM_CHECK_WITH_INFO(err == 0, "registerLocalMemory failed, addr: %p, len: %lu",
reinterpret_cast<void*>(desc.getAddr()), desc.getLen());
auto mooncakeDesc = std::make_shared<MooncakeMemoryDesc>(desc);
mMemRegInfo[desc.getAddr()] = std::move(mooncakeDesc);
}
}
void MooncakeTransferAgent::deregisterMemory(RegisterDescs const& descs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::deregisterMemory");
std::lock_guard<std::mutex> lock(mMutex);
for (auto const& desc : descs.getDescs())
{
auto it = mMemRegInfo.find(desc.getAddr());
if (it != mMemRegInfo.end())
{
auto const& mooncakeDesc = it->second;
mooncakeDesc->releaseRef();
if (mooncakeDesc->getRefCount())
continue;
int err = unregisterLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()));
TLLM_CHECK_WITH_INFO(
err == 0, "unregisterLocalMemory failed, addr: %p", reinterpret_cast<void*>(desc.getAddr()));
mMemRegInfo.erase(desc.getAddr());
}
}
}
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::loadRemoteAgent");
// Do the same thing as loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
loadRemoteAgent(name, std::move(agentDesc.getBackendAgentDesc()));
}
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"MooncakeTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
name.c_str());
std::lock_guard<std::mutex> lock(mMutex);
auto segmentId = openSegment(mEngine, connectionInfo.c_str());
TLLM_CHECK_WITH_INFO(
segmentId >= 0, "loadRemoteAgent openSegment failed, connectionInfo: %s", connectionInfo.c_str());
mConnectedAgents[name].segmentId = segmentId;
}
void MooncakeTransferAgent::invalidateRemoteAgent(std::string const& name)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::invalidateRemoteAgent");
}
AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");
// Using connection info as agent desc
const static size_t kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalIpAndPort failed");
return AgentDesc{std::string(connectionInfo)};
}
ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");
const static size_t kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalConnectionInfo failed");
return std::string(connectionInfo);
}
[[nodiscard]] std::unique_ptr<TransferStatus> MooncakeTransferAgent::submitTransferRequests(
TransferRequest const& request)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::submitTransferRequests");
bool hasNotif = false;
std::string syncMessage;
if (request.getSyncMessage().has_value())
{
hasNotif = true;
syncMessage = request.getSyncMessage().value();
}
const static size_t kMaxRequestCount = 1024;
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");
int segmentId;
{
std::lock_guard<std::mutex> lock(mMutex);
std::string remoteName = request.getRemoteName();
auto it = mConnectedAgents.find(remoteName);
if (it == mConnectedAgents.end())
{
std::string error = "Remote agent " + remoteName + "not found";
TLLM_THROW(error);
}
auto const& agentInfo = it->second;
segmentId = agentInfo.segmentId;
}
auto localDescs = request.getSrcDescs().getDescs();
auto remoteDescs = request.getDstDescs().getDescs();
TLLM_CHECK_WITH_INFO(localDescs.size() == remoteDescs.size(), "Number of local and remote memory must match");
size_t requestCount = localDescs.size();
std::vector<transfer_request_t> transferRequests(requestCount);
for (size_t index = 0; index < requestCount; ++index)
{
TLLM_CHECK_WITH_INFO(
localDescs[index].getLen() == remoteDescs[index].getLen(), "Length of local and remote memory must match");
transferRequests[index].opcode = (request.getOp() == TransferOp::kREAD) ? OPCODE_READ : OPCODE_WRITE;
transferRequests[index].source = reinterpret_cast<void*>(localDescs[index].getAddr());
transferRequests[index].target_offset = remoteDescs[index].getAddr();
transferRequests[index].length = localDescs[index].getLen();
transferRequests[index].target_id = segmentId;
}
int rc = 0;
if (hasNotif)
{
notify_msg_t notifyMsg;
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
notifyMsg.msg = const_cast<char*>(syncMessage.c_str());
rc = submitTransferWithNotify(mEngine, batchId, transferRequests.data(), requestCount, notifyMsg);
}
else
{
rc = submitTransfer(mEngine, batchId, transferRequests.data(), requestCount);
}
TLLM_CHECK_WITH_INFO(rc == 0, "submitTransfer failed with status: %d", rc);
return std::make_unique<MooncakeTransferStatus>(mEngine, batchId, requestCount);
}
void MooncakeTransferAgent::notifySyncMessage(std::string const& name, SyncMessage const& syncMessage)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage");
int segmentId;
{
std::lock_guard<std::mutex> lock(mMutex);
auto it = mConnectedAgents.find(name);
if (it == mConnectedAgents.end())
{
TLLM_LOG_WARNING("Remote agent %s not found", name.c_str());
return;
}
auto const& agentInfo = it->second;
segmentId = agentInfo.segmentId;
}
notify_msg_t notifyMsg;
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
std::string encoded = MooncakeBase64Helper::encode(syncMessage);
notifyMsg.msg = const_cast<char*>(encoded.c_str());
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage notifyMsg.name: %s, notifyMsg.msg: %s", notifyMsg.name,
notifyMsg.msg);
int ret = genNotifyInEngine(mEngine, segmentId, notifyMsg);
TLLM_CHECK_WITH_INFO(ret == 0, "genNotifyInEngine failed with status: %d", ret);
}
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> MooncakeTransferAgent::getNotifiedSyncMessages()
{
std::unordered_map<std::string, std::vector<SyncMessage>> notifs;
int size = 0;
notify_msg_t* notifyMsgs = getNotifsFromEngine(mEngine, &size);
TLLM_CHECK_WITH_INFO(size >= 0, "getNotifsFromEngine returned negative size: %d", size);
for (int i = 0; i < size; i++)
{
if (notifyMsgs[i].msg == nullptr)
{
TLLM_LOG_WARNING("Message pointer is null for: %s", notifyMsgs[i].name);
continue;
}
std::string decoded = MooncakeBase64Helper::decodeToString(notifyMsgs[i].msg);
notifs[notifyMsgs[i].name].emplace_back(std::move(decoded));
TLLM_LOG_DEBUG("MooncakeTransferAgent::getNotifiedSyncMessages getNotifsFromEngine: %s, %s", notifyMsgs[i].name,
notifyMsgs[i].msg);
}
freeNotifsMsgBuf(notifyMsgs, size);
return notifs;
}
bool MooncakeTransferAgent::checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::checkRemoteDescs");
return true;
}
MooncakeTransferAgent::~MooncakeTransferAgent()
{
destroyTransferEngine(mEngine);
TLLM_LOG_DEBUG("MooncakeTransferAgent::~MooncakeTransferAgent");
}
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
extern "C"
{
std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config)
{
TLLM_CHECK(config);
return std::make_unique<MooncakeTransferAgent>(*config);
}
}
} // namespace tensorrt_llm::executor::kv_cache

View File

@ -0,0 +1,165 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <atomic>
#include <mutex>
#include <thread>
#include <vector>
#include "tensorrt_llm/executor/transferAgent.h"
#include "transfer_engine_c.h"
namespace tensorrt_llm::executor::kv_cache
{
class MooncakeTransferStatus final : public TransferStatus
{
public:
MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount);
[[nodiscard]] bool isCompleted() const override;
void wait() const override;
private:
transfer_engine_t mEngine;
uint64_t mBatchId;
size_t mRequestCount;
mutable bool mBatchFreed = false;
};
class MooncakeMemoryDesc
{
public:
MooncakeMemoryDesc(MemoryDesc desc)
: mDesc{std::move(desc)}
, mRefCnt{0}
{
}
MooncakeMemoryDesc(MooncakeMemoryDesc const& other)
: mDesc{other.mDesc}
, mRefCnt{0}
{
}
MooncakeMemoryDesc& operator=(MooncakeMemoryDesc const&) = delete;
~MooncakeMemoryDesc() = default;
void addRef() noexcept
{
++mRefCnt;
}
int releaseRef() noexcept
{
return --mRefCnt;
}
int getRefCount() const noexcept
{
return mRefCnt;
}
MemoryDesc const& getDesc() const noexcept
{
return mDesc;
}
private:
MemoryDesc mDesc;
int mRefCnt;
};
class MooncakeBase64Helper
{
public:
static std::string encode(std::vector<uint8_t> const& data);
static std::string encode(std::string const& data);
static std::vector<uint8_t> decode(std::string const& encoded);
static std::string decodeToString(std::string const& encoded);
private:
static const std::string STANDARD_CHARS;
static std::string encodeInternal(std::vector<uint8_t> const& data, std::string const& chars);
static std::vector<uint8_t> decodeInternal(std::string const& encoded, std::string const& chars);
static inline bool isBase64(uint8_t c, std::string const& chars);
static inline bool isWhitespace(uint8_t c);
};
class MooncakeTransferAgent final : public BaseTransferAgent
{
public:
MooncakeTransferAgent(BaseAgentConfig const& config);
~MooncakeTransferAgent();
void registerMemory(RegisterDescs const& descs) override;
void deregisterMemory(RegisterDescs const& descs) override;
void loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc) override;
void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
void invalidateRemoteAgent(std::string const& name) override;
AgentDesc getLocalAgentDesc() override;
ConnectionInfoType getLocalConnectionInfo() override;
[[nodiscard]] std::unique_ptr<TransferStatus> submitTransferRequests(TransferRequest const& request) override;
void notifySyncMessage(std::string const& name, SyncMessage const& syncMessage) override;
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() override;
bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) override;
private:
struct AgentInfo
{
int segmentId;
};
mutable std::mutex mMutex;
transfer_engine_t mEngine;
std::unordered_map<uintptr_t, std::shared_ptr<MooncakeMemoryDesc>> mMemRegInfo;
std::unordered_map<std::string, AgentInfo> mConnectedAgents;
std::string mLocalAgentName;
};
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
extern "C"
{
[[nodiscard]] std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config);
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
} // namespace tensorrt_llm::executor::kv_cache

View File

@ -449,6 +449,7 @@ void initConfigBindings(nb::module_& m)
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
.value("MOONCAKE", tle::CacheTransceiverConfig::BackendType::MOONCAKE)
.def("from_string",
[](std::string const& str)
{
@ -460,6 +461,8 @@ void initConfigBindings(nb::module_& m)
return tle::CacheTransceiverConfig::BackendType::UCX;
if (str == "NIXL" || str == "nixl")
return tle::CacheTransceiverConfig::BackendType::NIXL;
if (str == "MOONCAKE" || str == "mooncake")
return tle::CacheTransceiverConfig::BackendType::MOONCAKE;
throw std::runtime_error("Invalid backend type: " + str);
});

View File

@ -431,6 +431,7 @@ void initConfigBindings(pybind11::module_& m)
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
.value("MOONCAKE", tle::CacheTransceiverConfig::BackendType::MOONCAKE)
.def("from_string",
[](std::string const& str)
{
@ -442,6 +443,8 @@ void initConfigBindings(pybind11::module_& m)
return tle::CacheTransceiverConfig::BackendType::UCX;
if (str == "NIXL" || str == "nixl")
return tle::CacheTransceiverConfig::BackendType::NIXL;
if (str == "MOONCAKE" || str == "mooncake")
return tle::CacheTransceiverConfig::BackendType::MOONCAKE;
throw std::runtime_error("Invalid backend type: " + str);
});

View File

@ -38,10 +38,31 @@ add_gtest(ucxCommTest ucxCommTest.cpp)
target_link_libraries(ucxCommTest PRIVATE ${Python3_LIBRARIES})
target_link_libraries(serializeUtilsTest PRIVATE ${Python3_LIBRARIES})
if(NIXL_ROOT)
add_gtest(transferAgentTest transferAgentTest.cpp)
add_gtest(agentCommTest agentCommTest.cpp)
target_link_libraries(transferAgentTest PRIVATE tensorrt_llm_nixl_wrapper)
target_link_libraries(agentCommTest PRIVATE tensorrt_llm_nixl_wrapper
${Python3_LIBRARIES})
# Skip MOONCAKE related tests on Rocky8
set(IS_ROCKY8 FALSE)
if(EXISTS "/etc/redhat-release")
set(IS_ROCKY8 TRUE)
endif()
if(NIXL_ROOT OR (MOONCAKE_ROOT AND NOT IS_ROCKY8))
add_gtest(agentCommTest agentCommTest.cpp)
add_gtest(transferAgentTest transferAgentTest.cpp)
if(NIXL_ROOT)
target_link_libraries(transferAgentTest PRIVATE tensorrt_llm_nixl_wrapper)
target_link_libraries(agentCommTest PRIVATE tensorrt_llm_nixl_wrapper
${Python3_LIBRARIES})
target_compile_definitions(transferAgentTest PRIVATE TEST_NIXL_BACKEND=1)
target_compile_definitions(agentCommTest PRIVATE TEST_NIXL_BACKEND=1)
endif()
if(MOONCAKE_ROOT)
target_link_libraries(transferAgentTest
PRIVATE tensorrt_llm_mooncake_wrapper)
target_link_libraries(agentCommTest PRIVATE tensorrt_llm_mooncake_wrapper
${Python3_LIBRARIES})
target_compile_definitions(transferAgentTest
PRIVATE TEST_MOONCAKE_BACKEND=1)
target_compile_definitions(agentCommTest PRIVATE TEST_MOONCAKE_BACKEND=1)
endif()
endif()

View File

@ -22,22 +22,54 @@ using namespace tensorrt_llm::batch_manager::kv_cache_manager;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::executor::kv_cache;
bool needSkipTest(std::string& skipReason)
std::vector<std::string> getAvailableBackends()
{
std::vector<std::string> backends;
#ifdef TEST_NIXL_BACKEND
backends.push_back("nixl");
#endif
#ifdef TEST_MOONCAKE_BACKEND
backends.push_back("mooncake");
#endif
return backends;
}
bool needSkipTest(std::string const& backend, std::string& skipReason)
{
bool skip = false;
try
{
auto& loader = tensorrt_llm::executor::kv_cache::DynLibLoader::getInstance();
using CreateNixlFuncType = std::unique_ptr<tensorrt_llm::executor::kv_cache::BaseTransferAgent> (*)(
tensorrt_llm::executor::kv_cache::BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateNixlFuncType>(
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
if (backend == "nixl")
{
using CreateNixlFuncType = std::unique_ptr<tensorrt_llm::executor::kv_cache::BaseTransferAgent> (*)(
tensorrt_llm::executor::kv_cache::BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateNixlFuncType>(
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
}
else if (backend == "mooncake")
{
using CreateMooncakeFuncType = std::unique_ptr<tensorrt_llm::executor::kv_cache::BaseTransferAgent> (*)(
tensorrt_llm::executor::kv_cache::BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
}
else
{
skip = true;
skipReason = "Unknown backend: " + backend;
}
}
catch (std::exception const& e)
{
std::string error = e.what();
if (error.find("libtensorrt_llm_nixl_wrapper.so") != std::string::npos)
std::string libName
= (backend == "nixl") ? "libtensorrt_llm_nixl_wrapper.so" : "libtensorrt_llm_mooncake_wrapper.so";
if (error.find(libName) != std::string::npos)
{
skip = true;
skipReason = error;
@ -46,17 +78,26 @@ bool needSkipTest(std::string& skipReason)
return skip;
}
class AgentCommTest : public ::testing::Test
class AgentCommTest : public ::testing::TestWithParam<std::string>
{
protected:
void SetUp() override
{
backend = GetParam();
std::string skipReason;
if (needSkipTest(skipReason))
if (needSkipTest(backend, skipReason))
{
GTEST_SKIP() << skipReason;
}
setenv("TRTLLM_USE_NIXL_KVCACHE", "1", 1);
if (backend == "nixl")
{
setenv("TRTLLM_USE_NIXL_KVCACHE", "1", 1);
}
else if (backend == "mooncake")
{
setenv("TRTLLM_USE_MOONCAKE_KVCACHE", "1", 1);
}
auto constexpr numLayers = 8;
auto constexpr numHeads = 16;
@ -106,15 +147,16 @@ protected:
mCacheState.reset();
}
std::string backend;
std::unique_ptr<CacheTransBufferManager> mTransBufferManager;
std::unique_ptr<KVCacheManager> mCacheManager;
std::unique_ptr<CacheState> mCacheState;
};
TEST_F(AgentCommTest, AgentConnectionManagerBasic)
TEST_P(AgentCommTest, AgentConnectionManagerBasic)
{
std::vector<CacheTransBufferManager*> bufferManagers{mTransBufferManager.get()};
auto connectionManager = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState);
auto connectionManager = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState, backend);
ASSERT_TRUE(connectionManager != nullptr);
ASSERT_EQ(connectionManager->getCacheTransBufferManagers().size(), bufferManagers.size());
ASSERT_TRUE(connectionManager->getCacheTransBufferManagers().front() != nullptr);
@ -126,11 +168,11 @@ TEST_F(AgentCommTest, AgentConnectionManagerBasic)
ASSERT_EQ(commState.getAgentState().size(), 1);
}
TEST_F(AgentCommTest, AgentConnectionManagerConnect)
TEST_P(AgentCommTest, AgentConnectionManagerConnect)
{
std::vector<CacheTransBufferManager*> bufferManagers{mTransBufferManager.get()};
auto connectionManager0 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState);
auto connectionManager1 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState);
auto connectionManager0 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState, backend);
auto connectionManager1 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState, backend);
auto agentName0 = connectionManager0->getAgentName();
auto agentName1 = connectionManager1->getAgentName();
ASSERT_TRUE(!agentName0.empty());
@ -189,3 +231,6 @@ TEST_F(AgentCommTest, AgentConnectionManagerConnect)
}
TLLM_LOG_INFO("after finish");
}
INSTANTIATE_TEST_SUITE_P(AvailableBackends, AgentCommTest, ::testing::ValuesIn(getAvailableBackends()),
[](::testing::TestParamInfo<AgentCommTest::ParamType> const& info) { return info.param; });

View File

@ -22,11 +22,27 @@
#include <gtest/gtest.h>
#include <filesystem>
#include <vector>
namespace fs = std::filesystem;
using namespace tensorrt_llm::executor::kv_cache;
std::vector<std::string> getAvailableBackends()
{
std::vector<std::string> backends;
#ifdef TEST_NIXL_BACKEND
backends.push_back("nixl");
#endif
#ifdef TEST_MOONCAKE_BACKEND
backends.push_back("mooncake");
#endif
return backends;
}
class RegisteredHostMemory
{
public:
@ -54,100 +70,105 @@ private:
BaseTransferAgent* mAgentPtr{};
};
class TransferAgentTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
class TransferAgentTest : public ::testing::TestWithParam<std::string> // NOLINT(cppcoreguidelines-pro-type-member-init)
{
public:
void SetUp() override {}
void SetUp() override
{
backend = GetParam();
}
void TearDown() override {}
[[nodiscard]] std::unique_ptr<BaseTransferAgent> makeTransferAgent(BaseAgentConfig const& config)
{
return tensorrt_llm::executor::kv_cache::makeTransferAgent("nixl", &config);
return tensorrt_llm::executor::kv_cache::makeTransferAgent(backend, &config);
}
std::string backend;
};
TEST_F(TransferAgentTest, Basic)
TEST_P(TransferAgentTest, Basic)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
auto nixlAgent0 = makeTransferAgent(config0);
auto nixlAgent1 = makeTransferAgent(config1);
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
TLLM_CHECK(nixlAgent0);
TLLM_CHECK(nixlAgent1);
TLLM_CHECK(xferAgent0);
TLLM_CHECK(xferAgent1);
std::vector<char> memory0(100, 10);
std::vector<char> memory1(100, 1);
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent0.get());
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
// nixlAgent0->loadRemoteAgent(agent1);
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
// xferAgent0->loadRemoteAgent(agent1);
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
bool checked = false;
do
{
checked = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
// wait for regMem is unpacked by nixlAgent0
checked = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
// wait for regMem is unpacked by xferAgent0
} while (!checked);
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
auto status = nixlAgent0->submitTransferRequests(writeReq);
auto status = xferAgent0->submitTransferRequests(writeReq);
status->wait();
TLLM_CHECK(memory0 == memory1);
nixlAgent0->invalidateRemoteAgent(agent1);
xferAgent0->invalidateRemoteAgent(agent1);
}
TEST_F(TransferAgentTest, Basic2)
TEST_P(TransferAgentTest, Basic2)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
auto nixlAgent0 = makeTransferAgent(config0);
auto nixlAgent1 = makeTransferAgent(config1);
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
TLLM_CHECK(nixlAgent0);
TLLM_CHECK(nixlAgent1);
TLLM_CHECK(xferAgent0);
TLLM_CHECK(xferAgent1);
std::vector<char> memory0(100, 10);
std::vector<char> memory1(100, 1);
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent0.get());
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
// nixlAgent0->loadRemoteAgent(agent1);
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
// xferAgent0->loadRemoteAgent(agent1);
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
bool checked = false;
do
{
checked = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
checked = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
} while (!checked);
TransferRequest readReq{TransferOp::kREAD, regMem0.getDescs(), regMem1.getDescs(), agent1};
auto status = nixlAgent0->submitTransferRequests(readReq);
auto status = xferAgent0->submitTransferRequests(readReq);
status->wait();
TLLM_CHECK(memory0 == memory1);
nixlAgent0->invalidateRemoteAgent(agent1);
xferAgent0->invalidateRemoteAgent(agent1);
}
TEST_F(TransferAgentTest, DeviceMemory)
TEST_P(TransferAgentTest, DeviceMemory)
{
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
auto nixlAgent0 = makeTransferAgent(config0);
auto nixlAgent1 = makeTransferAgent(config1);
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
TLLM_CHECK(nixlAgent0);
TLLM_CHECK(nixlAgent1);
TLLM_CHECK(xferAgent0);
TLLM_CHECK(xferAgent1);
char* dev_ptr0;
char* dev_ptr1;
size_t size = 100;
@ -159,20 +180,20 @@ TEST_F(TransferAgentTest, DeviceMemory)
cudaMemcpy(dev_ptr0, memory0.data(), size, cudaMemcpyHostToDevice);
cudaMemcpy(dev_ptr1, memory1.data(), size, cudaMemcpyHostToDevice);
RegisteredHostMemory regMem0(
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr0, size, deviceId}}}, nixlAgent0.get());
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr0, size, deviceId}}}, xferAgent0.get());
RegisteredHostMemory regMem1(
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr1, size, deviceId}}}, nixlAgent1.get());
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr1, size, deviceId}}}, xferAgent1.get());
// nixlAgent0->loadRemoteAgent(agent1);
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
// xferAgent0->loadRemoteAgent(agent1);
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
bool checked = false;
do
{
checked = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
checked = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
} while (!checked);
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
auto status = nixlAgent0->submitTransferRequests(writeReq);
auto status = xferAgent0->submitTransferRequests(writeReq);
status->wait();
cudaMemcpy(memory0.data(), dev_ptr0, size, cudaMemcpyDeviceToHost);
@ -181,98 +202,99 @@ TEST_F(TransferAgentTest, DeviceMemory)
TLLM_CHECK(memory0 == memory1);
TLLM_CUDA_CHECK(cudaFree(dev_ptr0));
TLLM_CUDA_CHECK(cudaFree(dev_ptr1));
nixlAgent0->invalidateRemoteAgent(agent1);
xferAgent0->invalidateRemoteAgent(agent1);
}
TEST_F(TransferAgentTest, Connect)
TEST_P(TransferAgentTest, Connect)
{
std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true}, config2{agent2, true};
auto nixlAgent0 = makeTransferAgent(config0);
auto nixlAgent1 = makeTransferAgent(config1);
auto nixlAgent2 = makeTransferAgent(config2);
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
auto xferAgent2 = makeTransferAgent(config2);
TLLM_CHECK(nixlAgent0);
TLLM_CHECK(nixlAgent1);
TLLM_CHECK(xferAgent0);
TLLM_CHECK(xferAgent1);
std::vector<char> memory0(100, 10);
std::vector<char> memory1(100, 1);
MemoryDescs memDescs0{MemoryType::kDRAM, {MemoryDesc{memory0}}};
MemoryDescs memDescs1{MemoryType::kDRAM, {MemoryDesc{memory1}}};
nixlAgent0->registerMemory(memDescs0);
nixlAgent1->registerMemory(memDescs1);
nixlAgent2->registerMemory(memDescs0);
xferAgent0->registerMemory(memDescs0);
xferAgent1->registerMemory(memDescs1);
xferAgent2->registerMemory(memDescs0);
// nixlAgent0->loadRemoteAgent(agent1);
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
// xferAgent0->loadRemoteAgent(agent1);
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
bool checked = false;
do
{
checked = nixlAgent0->checkRemoteDescs(agent1, memDescs1);
checked = xferAgent0->checkRemoteDescs(agent1, memDescs1);
} while (!checked);
TransferRequest writeReq{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
auto status = nixlAgent0->submitTransferRequests(writeReq);
auto status = xferAgent0->submitTransferRequests(writeReq);
status->wait();
TLLM_CHECK(memory0 == memory1);
nixlAgent2->loadRemoteAgent(agent1, connectionInfo);
xferAgent2->loadRemoteAgent(agent1, connectionInfo);
checked = false;
do
{
checked = nixlAgent2->checkRemoteDescs(agent1, memDescs1);
checked = xferAgent2->checkRemoteDescs(agent1, memDescs1);
} while (!checked);
TransferRequest writeReq2{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
auto status2 = nixlAgent2->submitTransferRequests(writeReq2);
auto status2 = xferAgent2->submitTransferRequests(writeReq2);
status2->wait();
TLLM_CHECK(memory0 == memory1);
nixlAgent0->invalidateRemoteAgent(agent1);
nixlAgent2->invalidateRemoteAgent(agent1);
nixlAgent0->deregisterMemory(memDescs0);
nixlAgent1->deregisterMemory(memDescs1);
nixlAgent2->deregisterMemory(memDescs0);
xferAgent0->invalidateRemoteAgent(agent1);
xferAgent2->invalidateRemoteAgent(agent1);
xferAgent0->deregisterMemory(memDescs0);
xferAgent1->deregisterMemory(memDescs1);
xferAgent2->deregisterMemory(memDescs0);
}
TEST_F(TransferAgentTest, SyncMessage)
TEST_P(TransferAgentTest, SyncMessage)
{
constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<size_t>::max();
std::string const agent0{"agent0"}, agent1{"agent1"};
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
auto nixlAgent0 = makeTransferAgent(config0);
auto nixlAgent1 = makeTransferAgent(config1);
auto xferAgent0 = makeTransferAgent(config0);
auto xferAgent1 = makeTransferAgent(config1);
TLLM_CHECK(nixlAgent0);
TLLM_CHECK(nixlAgent1);
TLLM_CHECK(xferAgent0);
TLLM_CHECK(xferAgent1);
std::vector<char> memory0(100, 10);
std::vector<char> memory1(100, 1);
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent0.get());
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent0.get());
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent0.get());
RegisteredHostMemory regMem2(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent1.get());
RegisteredHostMemory regMem3(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
RegisteredHostMemory regMem2(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent1.get());
RegisteredHostMemory regMem3(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
// nixlAgent0->loadRemoteAgent(agent1);
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
// xferAgent0->loadRemoteAgent(agent1);
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
bool checked = false;
do
{
checked = nixlAgent0->checkRemoteDescs(agent1, regMem3.getDescs());
checked = xferAgent0->checkRemoteDescs(agent1, regMem3.getDescs());
} while (!checked);
auto syncMessage = std::string("agent_sync_message");
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1};
auto status = nixlAgent0->submitTransferRequests(writeReq);
nixlAgent0->notifySyncMessage(agent1, syncMessage);
auto status = xferAgent0->submitTransferRequests(writeReq);
xferAgent0->notifySyncMessage(agent1, syncMessage);
auto notif = nixlAgent1->getNotifiedSyncMessages();
auto notif = xferAgent1->getNotifiedSyncMessages();
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif.size() == 0; i++)
{
notif = nixlAgent1->getNotifiedSyncMessages();
notif = xferAgent1->getNotifiedSyncMessages();
}
status->wait();
TLLM_CHECK(status->isCompleted());
TLLM_CHECK(notif.size() == 1);
TLLM_CHECK(notif[agent0].size() == 1);
@ -281,25 +303,25 @@ TEST_F(TransferAgentTest, SyncMessage)
TLLM_CHECK(memory0 == memory1);
std::string syncMessage2 = "two_agent_sync_message";
nixlAgent0->notifySyncMessage(agent1, syncMessage2);
auto notif2 = nixlAgent1->getNotifiedSyncMessages();
xferAgent0->notifySyncMessage(agent1, syncMessage2);
auto notif2 = xferAgent1->getNotifiedSyncMessages();
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif2.size() == 0; i++)
{
notif2 = nixlAgent1->getNotifiedSyncMessages();
notif2 = xferAgent1->getNotifiedSyncMessages();
}
TLLM_CHECK(notif2.size() == 1);
TLLM_CHECK(notif2[agent0].size() == 1);
TLLM_CHECK(notif2[agent0][0] == syncMessage2);
// nixlAgent1->loadRemoteAgent(agent0);
auto connectionInfo2 = nixlAgent0->getLocalConnectionInfo();
nixlAgent1->loadRemoteAgent(agent0, connectionInfo2);
// xferAgent1->loadRemoteAgent(agent0);
auto connectionInfo2 = xferAgent0->getLocalConnectionInfo();
xferAgent1->loadRemoteAgent(agent0, connectionInfo2);
std::string syncMessage3 = "three_agent_sync_message";
nixlAgent1->notifySyncMessage(agent0, syncMessage3);
auto notif3 = nixlAgent0->getNotifiedSyncMessages();
xferAgent1->notifySyncMessage(agent0, syncMessage3);
auto notif3 = xferAgent0->getNotifiedSyncMessages();
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif3.size() == 0; i++)
{
notif3 = nixlAgent0->getNotifiedSyncMessages();
notif3 = xferAgent0->getNotifiedSyncMessages();
}
TLLM_CHECK(notif3.size() == 1);
TLLM_CHECK(notif3[agent1].size() == 1);
@ -308,19 +330,20 @@ TEST_F(TransferAgentTest, SyncMessage)
bool checked2 = false;
do
{
checked2 = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
checked2 = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
} while (!checked2);
std::string syncMessage4 = "four_agent_sync_message";
TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0};
auto status1 = nixlAgent1->submitTransferRequests(writeReq1);
nixlAgent1->notifySyncMessage(agent0, syncMessage4);
auto status1 = xferAgent1->submitTransferRequests(writeReq1);
xferAgent1->notifySyncMessage(agent0, syncMessage4);
auto notif4 = nixlAgent0->getNotifiedSyncMessages();
auto notif4 = xferAgent0->getNotifiedSyncMessages();
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif4.size() == 0; i++)
{
notif4 = nixlAgent0->getNotifiedSyncMessages();
notif4 = xferAgent0->getNotifiedSyncMessages();
}
status1->wait();
TLLM_CHECK(status1->isCompleted());
TLLM_CHECK(notif4.size() == 1);
TLLM_CHECK(notif4[agent1].size() == 1);
@ -335,11 +358,11 @@ TEST_F(TransferAgentTest, SyncMessage)
std::stringstream ss;
Serialization::serialize(state, ss);
std::string serializedState = ss.str();
nixlAgent0->notifySyncMessage(agent1, serializedState);
auto notif5 = nixlAgent1->getNotifiedSyncMessages();
xferAgent0->notifySyncMessage(agent1, serializedState);
auto notif5 = xferAgent1->getNotifiedSyncMessages();
for (size_t i = 0; i < MAX_QUERY_TIMES && notif5.size() == 0; i++)
{
notif5 = nixlAgent1->getNotifiedSyncMessages();
notif5 = xferAgent1->getNotifiedSyncMessages();
}
TLLM_CHECK(notif5.size() == 1);
TLLM_CHECK(notif5[agent0].size() == 1);
@ -348,10 +371,16 @@ TEST_F(TransferAgentTest, SyncMessage)
auto state2 = Serialization::deserializeCommState(ss2);
TLLM_CHECK(state2 == state);
nixlAgent0->invalidateRemoteAgent(agent1);
nixlAgent1->invalidateRemoteAgent(agent0);
xferAgent0->invalidateRemoteAgent(agent1);
xferAgent1->invalidateRemoteAgent(agent0);
}
INSTANTIATE_TEST_SUITE_P(AvailableBackends, TransferAgentTest, ::testing::ValuesIn(getAvailableBackends()),
[](::testing::TestParamInfo<TransferAgentTest::ParamType> const& info) { return info.param; });
// Skip LoopbackAgentTest for mooncake backend for now
#ifdef TEST_NIXL_BACKEND
class LoopbackAgentTest : public ::testing::Test,
public ::testing::WithParamInterface<bool> // NOLINT(cppcoreguidelines-pro-type-member-init)
{
@ -466,3 +495,5 @@ TEST_P(LoopbackAgentTest, GpuToFile)
}
INSTANTIATE_TEST_SUITE_P(, LoopbackAgentTest, ::testing::Values(true, false));
#endif // TEST_NIXL_BACKEND

View File

@ -46,6 +46,7 @@
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <random>
#include <tensorrt_llm/batch_manager/cacheTransBuffer.h>
@ -713,7 +714,7 @@ protected:
return;
}
else if (tensorrt_llm::common::getEnvUseMPIKvCache() || tensorrt_llm::common::getEnvUseUCXKvCache()
|| tensorrt_llm::common::getEnvUseNixlKvCache())
|| tensorrt_llm::common::getEnvUseNixlKvCache() || tensorrt_llm::common::getEnvUseMooncakeKvCache())
{
int maxNumTokens = 2048;
mCacheTransBufferManagers.clear();
@ -729,7 +730,15 @@ protected:
}
bool isUcx = tensorrt_llm::common::getEnvUseUCXKvCache();
bool isNixl = tensorrt_llm::common::getEnvUseNixlKvCache();
TLLM_LOG_INFO("Enable %s KV cache transport.", isUcx ? "UCX" : isNixl ? "NIXL" : "MPI");
bool isMooncake = tensorrt_llm::common::getEnvUseMooncakeKvCache();
// Skip tests for MOONCAKE when on Rocky8
bool isRocky8 = std::filesystem::exists("/etc/redhat-release");
isMooncake = isMooncake && !isRocky8;
TLLM_LOG_INFO("Enable %s KV cache transport.",
isUcx ? "UCX"
: isNixl ? "NIXL"
: isMooncake ? "MOONCAKE"
: "MPI");
if (isUcx)
{
@ -756,7 +765,12 @@ protected:
setenv("TRTLLM_NIXL_PORT", std::to_string(port).c_str(), 1);
mConnectionManager
= std::make_unique<texec::kv_cache::AgentConnectionManager>(bufferManagers, *mCacheState);
= std::make_unique<texec::kv_cache::AgentConnectionManager>(bufferManagers, *mCacheState, "nixl");
}
else if (isMooncake)
{
mConnectionManager = std::make_unique<texec::kv_cache::AgentConnectionManager>(
bufferManagers, *mCacheState, "mooncake");
}
else
{
@ -783,7 +797,7 @@ protected:
std::vector<int> contextRankVec(mContextRankSize);
std::iota(contextRankVec.begin(), contextRankVec.end(), 0);
if (isUcx || isNixl)
if (isUcx || isNixl || isMooncake)
{
auto commState = mConnectionManager->getCommState();
namespace su = tensorrt_llm::executor::serialize_utils;
@ -1286,9 +1300,9 @@ TEST_P(AsymmetricalCacheTest, TestCase)
int indexerDimPerHead = std::get<17>(param);
int indexerKCacheQuantBlockSize = std::get<18>(param);
if (genCp > 1 && tensorrt_llm::common::getEnvUseNixlKvCache())
if (genCp > 1 && (tensorrt_llm::common::getEnvUseNixlKvCache() || tensorrt_llm::common::getEnvUseMooncakeKvCache()))
{
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL backend for CP.";
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
}
std::vector<int> lenList = {30, 10, 60, 80};
if (genCp > 1)
@ -1410,9 +1424,9 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
int indexerDimPerHead = std::get<17>(param);
int indexerKCacheQuantBlockSize = std::get<18>(param);
if (genCp > 1 && tensorrt_llm::common::getEnvUseNixlKvCache())
if (genCp > 1 && (tensorrt_llm::common::getEnvUseNixlKvCache() || tensorrt_llm::common::getEnvUseMooncakeKvCache()))
{
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL backend for CP.";
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
}
setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);

View File

@ -60,12 +60,12 @@ def BUILD_CONFIGS = [
// Vanilla TARNAME is used for packaging in runLLMPackage
// cmake-vars cannot be empty, so passing (default) multi-device configuration.
(CONFIG_LINUX_X86_64_VANILLA) : [
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake --micro_benchmarks",
(TARNAME) : "TensorRT-LLM.tar.gz",
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;103-real;120-real",
],
(CONFIG_LINUX_X86_64_PYBIND) : [
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake --micro_benchmarks",
(TARNAME) : "pybind-TensorRT-LLM.tar.gz",
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;103-real;120-real",
],
@ -80,13 +80,13 @@ def BUILD_CONFIGS = [
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;103-real;120-real",
],
(CONFIG_LINUX_AARCH64): [
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl",
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake",
(TARNAME) : "TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
],
(CONFIG_LINUX_AARCH64_PYBIND): [
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl",
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake",
(TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA

View File

@ -458,6 +458,7 @@ def main(*,
trt_root: str = '/usr/local/tensorrt',
nccl_root: str = None,
nixl_root: str = None,
mooncake_root: str = None,
internal_cutlass_kernels_root: str = None,
clean: bool = False,
clean_wheel: bool = False,
@ -559,6 +560,11 @@ def main(*,
if nixl_root is not None:
cmake_def_args.append(f"-DNIXL_ROOT={nixl_root}")
if mooncake_root is not None:
if on_windows:
raise RuntimeError("Mooncake is not supported on Windows.")
cmake_def_args.append(f"-DMOONCAKE_ROOT={mooncake_root}")
build_dir = get_build_dir(build_dir, build_type)
first_build = not Path(build_dir, "CMakeFiles").exists()
@ -819,6 +825,14 @@ def main(*,
build_run(
f"find {nixl_dir} -type f -name '*.so*' -exec patchelf --set-rpath \'$ORIGIN:$ORIGIN/plugins:$ORIGIN/../:$ORIGIN/../ucx/:$ORIGIN/../../ucx/\' {{}} \\;"
)
if os.path.exists(
build_dir /
"tensorrt_llm/executor/cache_transmission/mooncake_utils/libtensorrt_llm_mooncake_wrapper.so"
):
install_file(
build_dir /
"tensorrt_llm/executor/cache_transmission/mooncake_utils/libtensorrt_llm_mooncake_wrapper.so",
lib_dir / "libtensorrt_llm_mooncake_wrapper.so")
install_file(
build_dir /
"tensorrt_llm/kernels/decoderMaskedMultiheadAttention/libdecoder_attention_0.so",
@ -1041,6 +1055,10 @@ def add_arguments(parser: ArgumentParser):
help="Directory containing NCCL headers and libraries")
parser.add_argument("--nixl_root",
help="Directory containing NIXL headers and libraries")
parser.add_argument(
"--mooncake_root",
help=
"Directory containing Mooncake transfer engine headers and libraries")
parser.add_argument(
"--internal-cutlass-kernels-root",
default="",

View File

@ -114,9 +114,9 @@ else:
'libs/libnvinfer_plugin_tensorrt_llm.so',
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
'libs/ucx/**/*', 'libs/libpg_utils.so',
'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
'libs/libtensorrt_llm_mooncake_wrapper.so', 'libs/ucx/**/*',
'libs/libpg_utils.so', 'libs/libdecoder_attention_1.so',
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
'deep_ep/LICENSE', 'deep_ep/*.py', 'deep_ep_cpp_tllm.*.so',
"include/**/*", 'deep_gemm/LICENSE', 'deep_gemm/include/**/*',

View File

@ -42,6 +42,7 @@ def create_kv_cache_transceiver(
cache_transceiver_config.backend = "NIXL"
# Ordered by priority
env_vars = [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
("TRTLLM_USE_MOONCAKE_KVCACHE", "MOONCAKE"),
("TRTLLM_USE_MPI_KVCACHE", "MPI")]
for env_var, be_type in env_vars:
if getenv(env_var) == "1":

View File

@ -1739,10 +1739,11 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
Configuration for the cache transceiver.
"""
backend: Optional[Literal["DEFAULT", "UCX", "NIXL", "MPI"]] = Field(
default=None,
description=
"The communication backend type to use for the cache transceiver.")
backend: Optional[Literal[
"DEFAULT", "UCX", "NIXL", "MOONCAKE", "MPI"]] = Field(
default=None,
description=
"The communication backend type to use for the cache transceiver.")
max_tokens_in_buffer: Optional[int] = Field(
default=None,

View File

@ -25,6 +25,7 @@ class KVCacheType(Enum):
MPI = auto()
UCX = auto()
NIXL = auto()
MOONCAKE = auto()
def get_multi_gpu_env(kv_cache_type=KVCacheType.NONE, llama_multi_gpu=False):
@ -37,6 +38,9 @@ def get_multi_gpu_env(kv_cache_type=KVCacheType.NONE, llama_multi_gpu=False):
env["TRTLLM_USE_UCX_KVCACHE"] = "1"
case KVCacheType.NIXL:
env["TRTLLM_USE_NIXL_KVCACHE"] = "1"
case KVCacheType.MOONCAKE:
env["TRTLLM_USE_MOONCAKE_KVCACHE"] = "1"
env["MC_FORCE_TCP"] = "1"
case KVCacheType.NONE:
pass
case _:
@ -502,8 +506,9 @@ def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
indirect=True)
@pytest.mark.parametrize("kvcache_type", [KVCacheType.NIXL, KVCacheType.UCX],
ids=["nixl_kvcache", "ucx_kvcache"])
@pytest.mark.parametrize(
"kvcache_type", [KVCacheType.NIXL, KVCacheType.UCX, KVCacheType.MOONCAKE],
ids=["nixl_kvcache", "ucx_kvcache", "mooncake_kvcache"])
@pytest.mark.parametrize("nprocs", [2, 8], ids=["2proc", "8proc"])
def test_cache_transceiver(build_google_tests, nprocs, kvcache_type, build_dir):

View File

@ -231,6 +231,7 @@ l0_dgx_h100:
- cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90] ISOLATION
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-nixl_kvcache-90] ISOLATION
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-ucx_kvcache-90] ISOLATION
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] ISOLATION
- cpp/test_multi_gpu.py::test_user_buffer[2proc-90]
- cpp/test_multi_gpu.py::test_enc_dec[t5-90]
- cpp/test_multi_gpu.py::test_llama_executor[llama-orchestrator-90]