mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Support Mooncake transfer engine as a cache transceiver backend (#8309)
Signed-off-by: wjueyao <wyao123@terpmail.umd.edu> Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co> Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
This commit is contained in:
parent
e0b2a94309
commit
9f283f330b
@ -1468,7 +1468,8 @@ public:
|
||||
DEFAULT = 0,
|
||||
MPI = 1,
|
||||
UCX = 2,
|
||||
NIXL = 3
|
||||
NIXL = 3,
|
||||
MOONCAKE = 4
|
||||
};
|
||||
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
|
||||
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
|
||||
|
||||
@ -391,6 +391,14 @@ template <typename... Args>
|
||||
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
|
||||
return func(std::forward<Args>(args)...);
|
||||
}
|
||||
if (backend == "mooncake")
|
||||
{
|
||||
auto& loader = DynLibLoader::getInstance();
|
||||
using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
|
||||
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
|
||||
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
|
||||
return func(std::forward<Args>(args)...);
|
||||
}
|
||||
TLLM_THROW("Unknown backend name.");
|
||||
}
|
||||
|
||||
|
||||
@ -159,6 +159,10 @@ if(NIXL_ROOT)
|
||||
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
|
||||
endif()
|
||||
|
||||
if(MOONCAKE_ROOT)
|
||||
set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
|
||||
endif()
|
||||
|
||||
add_subdirectory(executor)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
@ -272,6 +276,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
|
||||
add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
|
||||
endif()
|
||||
|
||||
if(TARGET ${MOONCAKE_WRAPPER_TARGET})
|
||||
target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
|
||||
endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
# Load libraries at $PREFIX/lib from
|
||||
# $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs
|
||||
|
||||
@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
|
||||
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
|
||||
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
|
||||
}
|
||||
else if (common::getEnvUseMooncakeKvCache())
|
||||
{
|
||||
backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
|
||||
TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
|
||||
}
|
||||
else if (common::getEnvUseMPIKvCache())
|
||||
{
|
||||
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
|
||||
@ -203,9 +208,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
|
||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
|
||||
{
|
||||
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
||||
mCacheTransBufferManagerPtrs, *mCacheState);
|
||||
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
|
||||
TLLM_LOG_INFO("NIXL Connection Manager created");
|
||||
}
|
||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
|
||||
{
|
||||
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
||||
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
|
||||
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
|
||||
}
|
||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
|
||||
{
|
||||
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
|
||||
|
||||
@ -281,6 +281,12 @@ bool getEnvUseNixlKvCache()
|
||||
return useNixlKvCache;
|
||||
}
|
||||
|
||||
bool getEnvUseMooncakeKvCache()
|
||||
{
|
||||
static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
|
||||
return useMooncakeKvCache;
|
||||
}
|
||||
|
||||
bool getEnvUseRoundRobinBlockDistForCP()
|
||||
{
|
||||
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
|
||||
@ -343,6 +349,23 @@ std::string getEnvNixlBackend()
|
||||
return nixlBackend;
|
||||
}
|
||||
|
||||
std::string getEnvMooncakeInterface()
|
||||
{
|
||||
static std::once_flag flag;
|
||||
static std::string mooncakeInterface;
|
||||
|
||||
std::call_once(flag,
|
||||
[&]()
|
||||
{
|
||||
char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
|
||||
if (mooncake_interface)
|
||||
{
|
||||
mooncakeInterface = mooncake_interface;
|
||||
}
|
||||
});
|
||||
return mooncakeInterface;
|
||||
}
|
||||
|
||||
bool getEnvDisaggLayerwise()
|
||||
{
|
||||
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
|
||||
|
||||
@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
|
||||
bool getEnvUseUCXKvCache();
|
||||
|
||||
bool getEnvUseMPIKvCache();
|
||||
|
||||
bool getEnvUseNixlKvCache();
|
||||
|
||||
bool getEnvUseMooncakeKvCache();
|
||||
|
||||
bool getEnvUseRoundRobinBlockDistForCP();
|
||||
|
||||
std::string getEnvUCXInterface();
|
||||
@ -93,6 +96,8 @@ std::string getEnvNixlInterface();
|
||||
|
||||
std::string getEnvNixlBackend();
|
||||
|
||||
std::string getEnvMooncakeInterface();
|
||||
|
||||
bool getEnvDisaggLayerwise();
|
||||
|
||||
bool getEnvParallelCacheSend();
|
||||
|
||||
226
cpp/tensorrt_llm/common/ipUtils.cpp
Normal file
226
cpp/tensorrt_llm/common/ipUtils.cpp
Normal file
@ -0,0 +1,226 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ipUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <dirent.h>
|
||||
#include <fcntl.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <string>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace common
|
||||
{
|
||||
|
||||
std::string getLocalIpByNic(std::string const& interface, int rank)
|
||||
{
|
||||
struct ifaddrs* ifaddr = nullptr;
|
||||
if (getifaddrs(&ifaddr) == -1)
|
||||
{
|
||||
TLLM_LOG_ERROR(rank,
|
||||
"getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is "
|
||||
"set "
|
||||
"correctly.");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
|
||||
{
|
||||
if (ifa->ifa_addr == nullptr)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ifa->ifa_name == interface)
|
||||
{
|
||||
if (ifa->ifa_addr->sa_family == AF_INET)
|
||||
{
|
||||
char ip[INET_ADDRSTRLEN]{};
|
||||
void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
|
||||
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
|
||||
{
|
||||
freeifaddrs(ifaddr);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
else if (ifa->ifa_addr->sa_family == AF_INET6)
|
||||
{
|
||||
char ip[INET6_ADDRSTRLEN]{};
|
||||
void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
|
||||
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
|
||||
&& std::strcmp(ip, "::1") != 0)
|
||||
{
|
||||
freeifaddrs(ifaddr);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
freeifaddrs(ifaddr);
|
||||
TLLM_LOG_ERROR(
|
||||
rank, "Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is set correctly.");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
std::string getLocalIpByHostname(int rank)
|
||||
{
|
||||
char hostname[256]{};
|
||||
if (gethostname(hostname, sizeof(hostname)) == -1)
|
||||
{
|
||||
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
struct addrinfo hints = {};
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
hints.ai_flags = AI_CANONNAME;
|
||||
|
||||
struct addrinfo* res = nullptr;
|
||||
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
|
||||
{
|
||||
|
||||
if (p->ai_family == AF_INET)
|
||||
{ // IPv4
|
||||
char ip[INET_ADDRSTRLEN]{};
|
||||
struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
|
||||
void* addr = &(ipv4->sin_addr);
|
||||
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
|
||||
&& std::strcmp(ip, "0.0.0.0") != 0)
|
||||
{
|
||||
freeaddrinfo(res);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
else if (p->ai_family == AF_INET6)
|
||||
{ // IPv6
|
||||
char ip[INET6_ADDRSTRLEN]{};
|
||||
struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
|
||||
void* addr = &(ipv6->sin6_addr);
|
||||
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
|
||||
&& std::strcmp(ip, "::1") != 0)
|
||||
{
|
||||
freeaddrinfo(res);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
freeaddrinfo(res);
|
||||
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
std::string getLocalIpByRemoteOrHostName(int rank)
|
||||
{
|
||||
|
||||
// Try IPv4
|
||||
struct sockaddr_in addr
|
||||
{
|
||||
};
|
||||
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(80);
|
||||
// using google's public dns server to get the local ip which can be accessed from remote
|
||||
char const* dns_ip_v4 = "8.8.8.8";
|
||||
inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);
|
||||
|
||||
int sock = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (sock != -1)
|
||||
{
|
||||
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
|
||||
{
|
||||
socklen_t addr_len = sizeof(addr);
|
||||
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
|
||||
{
|
||||
char ip[INET_ADDRSTRLEN]{};
|
||||
inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
|
||||
close(sock);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
close(sock);
|
||||
}
|
||||
|
||||
// Try IPv6
|
||||
struct sockaddr_in6 addr6
|
||||
{
|
||||
};
|
||||
|
||||
addr6.sin6_family = AF_INET6;
|
||||
addr6.sin6_port = htons(80);
|
||||
// using google's public dns server
|
||||
char const* dns_ipv6 = "2001:4860:4860::8888";
|
||||
inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);
|
||||
|
||||
sock = socket(AF_INET6, SOCK_DGRAM, 0);
|
||||
if (sock != -1)
|
||||
{
|
||||
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
|
||||
{
|
||||
socklen_t addr_len = sizeof(addr6);
|
||||
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
|
||||
{
|
||||
char ip[INET6_ADDRSTRLEN]{};
|
||||
inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
|
||||
close(sock);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
close(sock);
|
||||
}
|
||||
|
||||
// Try hostname
|
||||
return getLocalIpByHostname(rank);
|
||||
}
|
||||
|
||||
std::string getLocalIp(std::string interface, int rank)
|
||||
{
|
||||
std::string localIP = {};
|
||||
if (!interface.empty())
|
||||
{
|
||||
localIP = getLocalIpByNic(interface, rank);
|
||||
}
|
||||
if (localIP.empty())
|
||||
{
|
||||
localIP = getLocalIpByRemoteOrHostName(rank);
|
||||
}
|
||||
// check whether the localIP is valid
|
||||
if (localIP.empty())
|
||||
{
|
||||
TLLM_THROW("getLocalIp: Can't get local ip");
|
||||
}
|
||||
return localIP;
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
28
cpp/tensorrt_llm/common/ipUtils.h
Normal file
28
cpp/tensorrt_llm/common/ipUtils.h
Normal file
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/config.h"
|
||||
#include <string>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace common
|
||||
{
|
||||
std::string getLocalIp(std::string interface, int rank);
|
||||
} // namespace common
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -91,3 +91,4 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
|
||||
|
||||
add_subdirectory(cache_transmission/ucx_utils)
|
||||
add_subdirectory(cache_transmission/nixl_utils)
|
||||
add_subdirectory(cache_transmission/mooncake_utils)
|
||||
|
||||
@ -236,7 +236,7 @@ bool AgentConnection::recvReadySignal(DataContext const& ctx) const
|
||||
|
||||
AgentConnectionManager::AgentConnectionManager(
|
||||
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
|
||||
CacheState cacheState)
|
||||
CacheState cacheState, std::string const& backendType)
|
||||
: mCacheState(std::move(cacheState))
|
||||
, mCacheTransBufferManagers(std::move(cacheTransBufferManagers))
|
||||
, mRegMemDescs(MemoryType::kVRAM, {})
|
||||
@ -247,7 +247,7 @@ AgentConnectionManager::AgentConnectionManager(
|
||||
mAgentName = genUniqueAgentName();
|
||||
// Create Agent
|
||||
BaseAgentConfig config{mAgentName, true};
|
||||
m_Agent = makeTransferAgent("nixl", &config);
|
||||
m_Agent = makeTransferAgent(backendType, &config);
|
||||
TLLM_CHECK(!mCacheTransBufferManagers.empty());
|
||||
std::vector<MemoryDesc> memDescs;
|
||||
for (auto* cacheTransBufferManager : mCacheTransBufferManagers)
|
||||
|
||||
@ -277,7 +277,7 @@ class AgentConnectionManager : public ConnectionManager
|
||||
public:
|
||||
AgentConnectionManager(
|
||||
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
|
||||
CacheState cacheState);
|
||||
CacheState cacheState, std::string const& backendType);
|
||||
~AgentConnectionManager();
|
||||
AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
|
||||
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
|
||||
# Source Code License Agreement
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this material and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
|
||||
# prohibited.
|
||||
|
||||
# MOONCAKE is not supported on Rocky8 for now
|
||||
set(IS_ROCKY8 FALSE)
|
||||
if(EXISTS "/etc/redhat-release")
|
||||
set(IS_ROCKY8 TRUE)
|
||||
endif()
|
||||
|
||||
if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
|
||||
find_library(TRANSFER_ENGINE_LIB transfer_engine ${MOONCAKE_ROOT}/lib)
|
||||
find_path(TRANSFER_ENGINE_INCLUDE_DIR transfer_engine_c.h
|
||||
${MOONCAKE_ROOT}/include)
|
||||
|
||||
message(STATUS "Find transfer engine results:")
|
||||
message(STATUS " TRANSFER_ENGINE_LIB = ${TRANSFER_ENGINE_LIB}")
|
||||
message(
|
||||
STATUS " TRANSFER_ENGINE_INCLUDE_DIR = ${TRANSFER_ENGINE_INCLUDE_DIR}")
|
||||
|
||||
if(TRANSFER_ENGINE_LIB AND TRANSFER_ENGINE_INCLUDE_DIR)
|
||||
set(MOONCAKE_WRAPPER_TARGET "tensorrt_llm_mooncake_wrapper")
|
||||
|
||||
add_library(${MOONCAKE_WRAPPER_TARGET} SHARED transferAgent.cpp)
|
||||
target_compile_options(${MOONCAKE_WRAPPER_TARGET} PRIVATE -Wno-error)
|
||||
|
||||
target_include_directories(${MOONCAKE_WRAPPER_TARGET}
|
||||
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
|
||||
|
||||
target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
|
||||
PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)
|
||||
endif()
|
||||
endif()
|
||||
@ -0,0 +1,546 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/common/ipUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <arpa/inet.h>
|
||||
#include <chrono>
|
||||
#include <dirent.h>
|
||||
#include <fcntl.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/file.h>
|
||||
#include <sys/stat.h>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace tensorrt_llm::executor::kv_cache
|
||||
{
|
||||
|
||||
MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount)
|
||||
: mEngine{engine}
|
||||
, mBatchId{batchId}
|
||||
, mRequestCount{requestCount}
|
||||
{
|
||||
TLLM_CHECK(mEngine);
|
||||
}
|
||||
|
||||
void MooncakeTransferStatus::wait() const
|
||||
{
|
||||
while (!isCompleted())
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] bool MooncakeTransferStatus::isCompleted() const
|
||||
{
|
||||
if (mBatchFreed)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
bool has_failed = false;
|
||||
for (size_t index = 0; index < mRequestCount; ++index)
|
||||
{
|
||||
transfer_status_t status;
|
||||
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
|
||||
if (rc || status.status == STATUS_FAILED)
|
||||
{
|
||||
has_failed = true;
|
||||
if (rc)
|
||||
{
|
||||
TLLM_LOG_ERROR(
|
||||
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
|
||||
}
|
||||
}
|
||||
else if (status.status == STATUS_PENDING || status.status == STATUS_WAITING)
|
||||
{
|
||||
TLLM_LOG_DEBUG("Transfer is pending for batch %lu, task %zu", mBatchId, index);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!has_failed)
|
||||
{
|
||||
// Each batchId has the batch size, and cannot process more requests
|
||||
// than the batch size. So, free the batch id here to workaround the issue
|
||||
// where the same batchId could be used to post multiple transfer.
|
||||
freeBatchID(mEngine, mBatchId);
|
||||
mBatchFreed = true;
|
||||
TLLM_LOG_DEBUG("Batch ID %lu freed, future calls will return true directly", mBatchId);
|
||||
}
|
||||
// Currently, we cannot distinguish between failed and completed from return value.
|
||||
TLLM_LOG_DEBUG("Transfer is completed for batch %lu", mBatchId);
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::string MooncakeBase64Helper::STANDARD_CHARS
|
||||
= "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"0123456789+/";
|
||||
|
||||
std::string MooncakeBase64Helper::encode(std::vector<uint8_t> const& data)
|
||||
{
|
||||
return encodeInternal(data, STANDARD_CHARS);
|
||||
}
|
||||
|
||||
std::string MooncakeBase64Helper::encode(std::string const& data)
|
||||
{
|
||||
std::vector<uint8_t> vec(data.begin(), data.end());
|
||||
return encode(vec);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> MooncakeBase64Helper::decode(std::string const& encoded)
|
||||
{
|
||||
return decodeInternal(encoded, STANDARD_CHARS);
|
||||
}
|
||||
|
||||
std::string MooncakeBase64Helper::decodeToString(std::string const& encoded)
|
||||
{
|
||||
auto vec = decode(encoded);
|
||||
return std::string(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
std::string MooncakeBase64Helper::encodeInternal(std::vector<uint8_t> const& data, std::string const& chars)
|
||||
{
|
||||
std::string encoded;
|
||||
size_t i = 0;
|
||||
size_t j = 0;
|
||||
std::array<uint8_t, 3> charArray3{};
|
||||
std::array<uint8_t, 4> charArray4{};
|
||||
size_t dataLen = data.size();
|
||||
uint8_t const* bytes = data.data();
|
||||
|
||||
while (dataLen--)
|
||||
{
|
||||
charArray3[i++] = *(bytes++);
|
||||
if (i == 3)
|
||||
{
|
||||
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
|
||||
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
|
||||
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
|
||||
charArray4[3] = charArray3[2] & 0x3f;
|
||||
|
||||
for (i = 0; i < 4; i++)
|
||||
{
|
||||
encoded += chars[charArray4[i]];
|
||||
}
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (i > 0)
|
||||
{
|
||||
for (j = i; j < 3; j++)
|
||||
{
|
||||
charArray3[j] = '\0';
|
||||
}
|
||||
|
||||
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
|
||||
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
|
||||
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
|
||||
charArray4[3] = charArray3[2] & 0x3f;
|
||||
|
||||
for (j = 0; j < i + 1; j++)
|
||||
{
|
||||
encoded += chars[charArray4[j]];
|
||||
}
|
||||
|
||||
while (i++ < 3)
|
||||
{
|
||||
encoded += '=';
|
||||
}
|
||||
}
|
||||
|
||||
return encoded;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> MooncakeBase64Helper::decodeInternal(std::string const& encoded, std::string const& chars)
|
||||
{
|
||||
size_t encodedLen = encoded.size();
|
||||
size_t i = 0;
|
||||
size_t j = 0;
|
||||
size_t in_ = 0;
|
||||
std::array<uint8_t, 3> charArray3{};
|
||||
std::array<uint8_t, 4> charArray4{};
|
||||
std::vector<uint8_t> decoded;
|
||||
|
||||
std::string cleanEncoded;
|
||||
for (char c : encoded)
|
||||
{
|
||||
if (!isWhitespace(c))
|
||||
{
|
||||
cleanEncoded += c;
|
||||
}
|
||||
}
|
||||
|
||||
encodedLen = cleanEncoded.size();
|
||||
|
||||
while (encodedLen-- && cleanEncoded[in_] != '=' && isBase64(cleanEncoded[in_], chars))
|
||||
{
|
||||
charArray4[i++] = cleanEncoded[in_];
|
||||
in_++;
|
||||
if (i == 4)
|
||||
{
|
||||
for (i = 0; i < 4; i++)
|
||||
{
|
||||
charArray4[i] = chars.find(charArray4[i]);
|
||||
}
|
||||
|
||||
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
|
||||
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
|
||||
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
|
||||
|
||||
for (i = 0; i < 3; i++)
|
||||
{
|
||||
decoded.push_back(charArray3[i]);
|
||||
}
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (i > 0)
|
||||
{
|
||||
for (j = i; j < 4; j++)
|
||||
{
|
||||
charArray4[j] = 0;
|
||||
}
|
||||
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
charArray4[j] = chars.find(charArray4[j]);
|
||||
}
|
||||
|
||||
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
|
||||
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
|
||||
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
|
||||
|
||||
for (j = 0; j < i - 1; j++)
|
||||
{
|
||||
decoded.push_back(charArray3[j]);
|
||||
}
|
||||
}
|
||||
|
||||
return decoded;
|
||||
}
|
||||
|
||||
bool MooncakeBase64Helper::isBase64(uint8_t c, std::string const& chars)
|
||||
{
|
||||
return (isalnum(c) || (c == chars[62]) || (c == chars[63]));
|
||||
}
|
||||
|
||||
bool MooncakeBase64Helper::isWhitespace(uint8_t c)
|
||||
{
|
||||
return (c == ' ' || c == '\n' || c == '\r' || c == '\t');
|
||||
}
|
||||
|
||||
MooncakeTransferAgent::MooncakeTransferAgent(BaseAgentConfig const& config)
|
||||
{
|
||||
mLocalAgentName = config.mName;
|
||||
std::string segmentName = "127.0.0.1";
|
||||
|
||||
if (getenv("TLLM_MOONCAKE_IP_ADDR"))
|
||||
{
|
||||
segmentName = std::string(getenv("TLLM_MOONCAKE_IP_ADDR"));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ip = common::getLocalIp(common::getEnvMooncakeInterface(), mpi::MpiComm::session().getRank());
|
||||
if (!ip.empty())
|
||||
segmentName = ip;
|
||||
}
|
||||
|
||||
mEngine = createTransferEngine("P2PHANDSHAKE", segmentName.c_str(), "", 0, true);
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::registerMemory(RegisterDescs const& descs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::registerMemory");
|
||||
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
for (auto const& desc : descs.getDescs())
|
||||
{
|
||||
auto it = mMemRegInfo.find(desc.getAddr());
|
||||
if (it != mMemRegInfo.end())
|
||||
{
|
||||
it->second->addRef();
|
||||
continue;
|
||||
}
|
||||
|
||||
int err = registerLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()), desc.getLen(), "*", 1);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(err == 0, "registerLocalMemory failed, addr: %p, len: %lu",
|
||||
reinterpret_cast<void*>(desc.getAddr()), desc.getLen());
|
||||
|
||||
auto mooncakeDesc = std::make_shared<MooncakeMemoryDesc>(desc);
|
||||
mMemRegInfo[desc.getAddr()] = std::move(mooncakeDesc);
|
||||
}
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::deregisterMemory(RegisterDescs const& descs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::deregisterMemory");
|
||||
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
for (auto const& desc : descs.getDescs())
|
||||
{
|
||||
auto it = mMemRegInfo.find(desc.getAddr());
|
||||
if (it != mMemRegInfo.end())
|
||||
{
|
||||
auto const& mooncakeDesc = it->second;
|
||||
mooncakeDesc->releaseRef();
|
||||
if (mooncakeDesc->getRefCount())
|
||||
continue;
|
||||
|
||||
int err = unregisterLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
err == 0, "unregisterLocalMemory failed, addr: %p", reinterpret_cast<void*>(desc.getAddr()));
|
||||
|
||||
mMemRegInfo.erase(desc.getAddr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::loadRemoteAgent");
|
||||
|
||||
// Do the same thing as loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
|
||||
loadRemoteAgent(name, std::move(agentDesc.getBackendAgentDesc()));
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
|
||||
{
|
||||
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
|
||||
"MooncakeTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
|
||||
name.c_str());
|
||||
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
auto segmentId = openSegment(mEngine, connectionInfo.c_str());
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
segmentId >= 0, "loadRemoteAgent openSegment failed, connectionInfo: %s", connectionInfo.c_str());
|
||||
|
||||
mConnectedAgents[name].segmentId = segmentId;
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::invalidateRemoteAgent(std::string const& name)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::invalidateRemoteAgent");
|
||||
}
|
||||
|
||||
AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");
|
||||
|
||||
// Using connection info as agent desc
|
||||
const static size_t kBufLen = 64;
|
||||
char connectionInfo[kBufLen];
|
||||
|
||||
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalIpAndPort failed");
|
||||
|
||||
return AgentDesc{std::string(connectionInfo)};
|
||||
}
|
||||
|
||||
ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");
|
||||
|
||||
const static size_t kBufLen = 64;
|
||||
char connectionInfo[kBufLen];
|
||||
|
||||
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalConnectionInfo failed");
|
||||
|
||||
return std::string(connectionInfo);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::unique_ptr<TransferStatus> MooncakeTransferAgent::submitTransferRequests(
|
||||
TransferRequest const& request)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::submitTransferRequests");
|
||||
|
||||
bool hasNotif = false;
|
||||
std::string syncMessage;
|
||||
|
||||
if (request.getSyncMessage().has_value())
|
||||
{
|
||||
hasNotif = true;
|
||||
syncMessage = request.getSyncMessage().value();
|
||||
}
|
||||
|
||||
const static size_t kMaxRequestCount = 1024;
|
||||
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");
|
||||
|
||||
int segmentId;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
std::string remoteName = request.getRemoteName();
|
||||
|
||||
auto it = mConnectedAgents.find(remoteName);
|
||||
if (it == mConnectedAgents.end())
|
||||
{
|
||||
std::string error = "Remote agent " + remoteName + "not found";
|
||||
TLLM_THROW(error);
|
||||
}
|
||||
|
||||
auto const& agentInfo = it->second;
|
||||
segmentId = agentInfo.segmentId;
|
||||
}
|
||||
|
||||
auto localDescs = request.getSrcDescs().getDescs();
|
||||
auto remoteDescs = request.getDstDescs().getDescs();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(localDescs.size() == remoteDescs.size(), "Number of local and remote memory must match");
|
||||
|
||||
size_t requestCount = localDescs.size();
|
||||
std::vector<transfer_request_t> transferRequests(requestCount);
|
||||
|
||||
for (size_t index = 0; index < requestCount; ++index)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
localDescs[index].getLen() == remoteDescs[index].getLen(), "Length of local and remote memory must match");
|
||||
|
||||
transferRequests[index].opcode = (request.getOp() == TransferOp::kREAD) ? OPCODE_READ : OPCODE_WRITE;
|
||||
transferRequests[index].source = reinterpret_cast<void*>(localDescs[index].getAddr());
|
||||
transferRequests[index].target_offset = remoteDescs[index].getAddr();
|
||||
transferRequests[index].length = localDescs[index].getLen();
|
||||
transferRequests[index].target_id = segmentId;
|
||||
}
|
||||
|
||||
int rc = 0;
|
||||
if (hasNotif)
|
||||
{
|
||||
notify_msg_t notifyMsg;
|
||||
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
|
||||
notifyMsg.msg = const_cast<char*>(syncMessage.c_str());
|
||||
rc = submitTransferWithNotify(mEngine, batchId, transferRequests.data(), requestCount, notifyMsg);
|
||||
}
|
||||
else
|
||||
{
|
||||
rc = submitTransfer(mEngine, batchId, transferRequests.data(), requestCount);
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO(rc == 0, "submitTransfer failed with status: %d", rc);
|
||||
|
||||
return std::make_unique<MooncakeTransferStatus>(mEngine, batchId, requestCount);
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::notifySyncMessage(std::string const& name, SyncMessage const& syncMessage)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage");
|
||||
int segmentId;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
auto it = mConnectedAgents.find(name);
|
||||
|
||||
if (it == mConnectedAgents.end())
|
||||
{
|
||||
TLLM_LOG_WARNING("Remote agent %s not found", name.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
auto const& agentInfo = it->second;
|
||||
segmentId = agentInfo.segmentId;
|
||||
}
|
||||
|
||||
notify_msg_t notifyMsg;
|
||||
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
|
||||
std::string encoded = MooncakeBase64Helper::encode(syncMessage);
|
||||
notifyMsg.msg = const_cast<char*>(encoded.c_str());
|
||||
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage notifyMsg.name: %s, notifyMsg.msg: %s", notifyMsg.name,
|
||||
notifyMsg.msg);
|
||||
|
||||
int ret = genNotifyInEngine(mEngine, segmentId, notifyMsg);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(ret == 0, "genNotifyInEngine failed with status: %d", ret);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> MooncakeTransferAgent::getNotifiedSyncMessages()
|
||||
{
|
||||
std::unordered_map<std::string, std::vector<SyncMessage>> notifs;
|
||||
int size = 0;
|
||||
|
||||
notify_msg_t* notifyMsgs = getNotifsFromEngine(mEngine, &size);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(size >= 0, "getNotifsFromEngine returned negative size: %d", size);
|
||||
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
if (notifyMsgs[i].msg == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("Message pointer is null for: %s", notifyMsgs[i].name);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string decoded = MooncakeBase64Helper::decodeToString(notifyMsgs[i].msg);
|
||||
notifs[notifyMsgs[i].name].emplace_back(std::move(decoded));
|
||||
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::getNotifiedSyncMessages getNotifsFromEngine: %s, %s", notifyMsgs[i].name,
|
||||
notifyMsgs[i].msg);
|
||||
}
|
||||
|
||||
freeNotifsMsgBuf(notifyMsgs, size);
|
||||
return notifs;
|
||||
}
|
||||
|
||||
bool MooncakeTransferAgent::checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::checkRemoteDescs");
|
||||
return true;
|
||||
}
|
||||
|
||||
MooncakeTransferAgent::~MooncakeTransferAgent()
|
||||
{
|
||||
destroyTransferEngine(mEngine);
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::~MooncakeTransferAgent");
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
|
||||
#endif
|
||||
|
||||
extern "C"
|
||||
{
|
||||
std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config)
|
||||
{
|
||||
TLLM_CHECK(config);
|
||||
return std::make_unique<MooncakeTransferAgent>(*config);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
@ -0,0 +1,165 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
#include "transfer_engine_c.h"
|
||||
|
||||
namespace tensorrt_llm::executor::kv_cache
|
||||
{
|
||||
|
||||
class MooncakeTransferStatus final : public TransferStatus
|
||||
{
|
||||
public:
|
||||
MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount);
|
||||
|
||||
[[nodiscard]] bool isCompleted() const override;
|
||||
|
||||
void wait() const override;
|
||||
|
||||
private:
|
||||
transfer_engine_t mEngine;
|
||||
uint64_t mBatchId;
|
||||
size_t mRequestCount;
|
||||
mutable bool mBatchFreed = false;
|
||||
};
|
||||
|
||||
class MooncakeMemoryDesc
|
||||
{
|
||||
public:
|
||||
MooncakeMemoryDesc(MemoryDesc desc)
|
||||
: mDesc{std::move(desc)}
|
||||
, mRefCnt{0}
|
||||
{
|
||||
}
|
||||
|
||||
MooncakeMemoryDesc(MooncakeMemoryDesc const& other)
|
||||
: mDesc{other.mDesc}
|
||||
, mRefCnt{0}
|
||||
{
|
||||
}
|
||||
|
||||
MooncakeMemoryDesc& operator=(MooncakeMemoryDesc const&) = delete;
|
||||
|
||||
~MooncakeMemoryDesc() = default;
|
||||
|
||||
void addRef() noexcept
|
||||
{
|
||||
++mRefCnt;
|
||||
}
|
||||
|
||||
int releaseRef() noexcept
|
||||
{
|
||||
return --mRefCnt;
|
||||
}
|
||||
|
||||
int getRefCount() const noexcept
|
||||
{
|
||||
return mRefCnt;
|
||||
}
|
||||
|
||||
MemoryDesc const& getDesc() const noexcept
|
||||
{
|
||||
return mDesc;
|
||||
}
|
||||
|
||||
private:
|
||||
MemoryDesc mDesc;
|
||||
int mRefCnt;
|
||||
};
|
||||
|
||||
class MooncakeBase64Helper
|
||||
{
|
||||
public:
|
||||
static std::string encode(std::vector<uint8_t> const& data);
|
||||
static std::string encode(std::string const& data);
|
||||
|
||||
static std::vector<uint8_t> decode(std::string const& encoded);
|
||||
static std::string decodeToString(std::string const& encoded);
|
||||
|
||||
private:
|
||||
static const std::string STANDARD_CHARS;
|
||||
|
||||
static std::string encodeInternal(std::vector<uint8_t> const& data, std::string const& chars);
|
||||
static std::vector<uint8_t> decodeInternal(std::string const& encoded, std::string const& chars);
|
||||
|
||||
static inline bool isBase64(uint8_t c, std::string const& chars);
|
||||
static inline bool isWhitespace(uint8_t c);
|
||||
};
|
||||
|
||||
class MooncakeTransferAgent final : public BaseTransferAgent
|
||||
{
|
||||
public:
|
||||
MooncakeTransferAgent(BaseAgentConfig const& config);
|
||||
~MooncakeTransferAgent();
|
||||
|
||||
void registerMemory(RegisterDescs const& descs) override;
|
||||
|
||||
void deregisterMemory(RegisterDescs const& descs) override;
|
||||
|
||||
void loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc) override;
|
||||
|
||||
void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
|
||||
|
||||
void invalidateRemoteAgent(std::string const& name) override;
|
||||
|
||||
AgentDesc getLocalAgentDesc() override;
|
||||
|
||||
ConnectionInfoType getLocalConnectionInfo() override;
|
||||
|
||||
[[nodiscard]] std::unique_ptr<TransferStatus> submitTransferRequests(TransferRequest const& request) override;
|
||||
|
||||
void notifySyncMessage(std::string const& name, SyncMessage const& syncMessage) override;
|
||||
|
||||
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() override;
|
||||
|
||||
bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) override;
|
||||
|
||||
private:
|
||||
struct AgentInfo
|
||||
{
|
||||
int segmentId;
|
||||
};
|
||||
|
||||
mutable std::mutex mMutex;
|
||||
transfer_engine_t mEngine;
|
||||
std::unordered_map<uintptr_t, std::shared_ptr<MooncakeMemoryDesc>> mMemRegInfo;
|
||||
std::unordered_map<std::string, AgentInfo> mConnectedAgents;
|
||||
std::string mLocalAgentName;
|
||||
};
|
||||
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
|
||||
#endif
|
||||
|
||||
extern "C"
|
||||
{
|
||||
[[nodiscard]] std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config);
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
@ -449,6 +449,7 @@ void initConfigBindings(nb::module_& m)
|
||||
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
|
||||
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
|
||||
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
|
||||
.value("MOONCAKE", tle::CacheTransceiverConfig::BackendType::MOONCAKE)
|
||||
.def("from_string",
|
||||
[](std::string const& str)
|
||||
{
|
||||
@ -460,6 +461,8 @@ void initConfigBindings(nb::module_& m)
|
||||
return tle::CacheTransceiverConfig::BackendType::UCX;
|
||||
if (str == "NIXL" || str == "nixl")
|
||||
return tle::CacheTransceiverConfig::BackendType::NIXL;
|
||||
if (str == "MOONCAKE" || str == "mooncake")
|
||||
return tle::CacheTransceiverConfig::BackendType::MOONCAKE;
|
||||
throw std::runtime_error("Invalid backend type: " + str);
|
||||
});
|
||||
|
||||
|
||||
@ -431,6 +431,7 @@ void initConfigBindings(pybind11::module_& m)
|
||||
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
|
||||
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
|
||||
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
|
||||
.value("MOONCAKE", tle::CacheTransceiverConfig::BackendType::MOONCAKE)
|
||||
.def("from_string",
|
||||
[](std::string const& str)
|
||||
{
|
||||
@ -442,6 +443,8 @@ void initConfigBindings(pybind11::module_& m)
|
||||
return tle::CacheTransceiverConfig::BackendType::UCX;
|
||||
if (str == "NIXL" || str == "nixl")
|
||||
return tle::CacheTransceiverConfig::BackendType::NIXL;
|
||||
if (str == "MOONCAKE" || str == "mooncake")
|
||||
return tle::CacheTransceiverConfig::BackendType::MOONCAKE;
|
||||
throw std::runtime_error("Invalid backend type: " + str);
|
||||
});
|
||||
|
||||
|
||||
@ -38,10 +38,31 @@ add_gtest(ucxCommTest ucxCommTest.cpp)
|
||||
target_link_libraries(ucxCommTest PRIVATE ${Python3_LIBRARIES})
|
||||
target_link_libraries(serializeUtilsTest PRIVATE ${Python3_LIBRARIES})
|
||||
|
||||
if(NIXL_ROOT)
|
||||
add_gtest(transferAgentTest transferAgentTest.cpp)
|
||||
add_gtest(agentCommTest agentCommTest.cpp)
|
||||
target_link_libraries(transferAgentTest PRIVATE tensorrt_llm_nixl_wrapper)
|
||||
target_link_libraries(agentCommTest PRIVATE tensorrt_llm_nixl_wrapper
|
||||
${Python3_LIBRARIES})
|
||||
# Skip MOONCAKE related tests on Rocky8
|
||||
set(IS_ROCKY8 FALSE)
|
||||
if(EXISTS "/etc/redhat-release")
|
||||
set(IS_ROCKY8 TRUE)
|
||||
endif()
|
||||
|
||||
if(NIXL_ROOT OR (MOONCAKE_ROOT AND NOT IS_ROCKY8))
|
||||
add_gtest(agentCommTest agentCommTest.cpp)
|
||||
add_gtest(transferAgentTest transferAgentTest.cpp)
|
||||
|
||||
if(NIXL_ROOT)
|
||||
target_link_libraries(transferAgentTest PRIVATE tensorrt_llm_nixl_wrapper)
|
||||
target_link_libraries(agentCommTest PRIVATE tensorrt_llm_nixl_wrapper
|
||||
${Python3_LIBRARIES})
|
||||
target_compile_definitions(transferAgentTest PRIVATE TEST_NIXL_BACKEND=1)
|
||||
target_compile_definitions(agentCommTest PRIVATE TEST_NIXL_BACKEND=1)
|
||||
endif()
|
||||
|
||||
if(MOONCAKE_ROOT)
|
||||
target_link_libraries(transferAgentTest
|
||||
PRIVATE tensorrt_llm_mooncake_wrapper)
|
||||
target_link_libraries(agentCommTest PRIVATE tensorrt_llm_mooncake_wrapper
|
||||
${Python3_LIBRARIES})
|
||||
target_compile_definitions(transferAgentTest
|
||||
PRIVATE TEST_MOONCAKE_BACKEND=1)
|
||||
target_compile_definitions(agentCommTest PRIVATE TEST_MOONCAKE_BACKEND=1)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -22,22 +22,54 @@ using namespace tensorrt_llm::batch_manager::kv_cache_manager;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
using namespace tensorrt_llm::executor::kv_cache;
|
||||
|
||||
bool needSkipTest(std::string& skipReason)
|
||||
std::vector<std::string> getAvailableBackends()
|
||||
{
|
||||
std::vector<std::string> backends;
|
||||
|
||||
#ifdef TEST_NIXL_BACKEND
|
||||
backends.push_back("nixl");
|
||||
#endif
|
||||
|
||||
#ifdef TEST_MOONCAKE_BACKEND
|
||||
backends.push_back("mooncake");
|
||||
#endif
|
||||
|
||||
return backends;
|
||||
}
|
||||
|
||||
bool needSkipTest(std::string const& backend, std::string& skipReason)
|
||||
{
|
||||
bool skip = false;
|
||||
try
|
||||
{
|
||||
auto& loader = tensorrt_llm::executor::kv_cache::DynLibLoader::getInstance();
|
||||
|
||||
using CreateNixlFuncType = std::unique_ptr<tensorrt_llm::executor::kv_cache::BaseTransferAgent> (*)(
|
||||
tensorrt_llm::executor::kv_cache::BaseAgentConfig const*);
|
||||
auto* func = loader.getFunctionPointer<CreateNixlFuncType>(
|
||||
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
|
||||
if (backend == "nixl")
|
||||
{
|
||||
using CreateNixlFuncType = std::unique_ptr<tensorrt_llm::executor::kv_cache::BaseTransferAgent> (*)(
|
||||
tensorrt_llm::executor::kv_cache::BaseAgentConfig const*);
|
||||
auto* func = loader.getFunctionPointer<CreateNixlFuncType>(
|
||||
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
|
||||
}
|
||||
else if (backend == "mooncake")
|
||||
{
|
||||
using CreateMooncakeFuncType = std::unique_ptr<tensorrt_llm::executor::kv_cache::BaseTransferAgent> (*)(
|
||||
tensorrt_llm::executor::kv_cache::BaseAgentConfig const*);
|
||||
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
|
||||
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
|
||||
}
|
||||
else
|
||||
{
|
||||
skip = true;
|
||||
skipReason = "Unknown backend: " + backend;
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
std::string error = e.what();
|
||||
if (error.find("libtensorrt_llm_nixl_wrapper.so") != std::string::npos)
|
||||
std::string libName
|
||||
= (backend == "nixl") ? "libtensorrt_llm_nixl_wrapper.so" : "libtensorrt_llm_mooncake_wrapper.so";
|
||||
if (error.find(libName) != std::string::npos)
|
||||
{
|
||||
skip = true;
|
||||
skipReason = error;
|
||||
@ -46,17 +78,26 @@ bool needSkipTest(std::string& skipReason)
|
||||
return skip;
|
||||
}
|
||||
|
||||
class AgentCommTest : public ::testing::Test
|
||||
class AgentCommTest : public ::testing::TestWithParam<std::string>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
backend = GetParam();
|
||||
std::string skipReason;
|
||||
if (needSkipTest(skipReason))
|
||||
if (needSkipTest(backend, skipReason))
|
||||
{
|
||||
GTEST_SKIP() << skipReason;
|
||||
}
|
||||
setenv("TRTLLM_USE_NIXL_KVCACHE", "1", 1);
|
||||
|
||||
if (backend == "nixl")
|
||||
{
|
||||
setenv("TRTLLM_USE_NIXL_KVCACHE", "1", 1);
|
||||
}
|
||||
else if (backend == "mooncake")
|
||||
{
|
||||
setenv("TRTLLM_USE_MOONCAKE_KVCACHE", "1", 1);
|
||||
}
|
||||
|
||||
auto constexpr numLayers = 8;
|
||||
auto constexpr numHeads = 16;
|
||||
@ -106,15 +147,16 @@ protected:
|
||||
mCacheState.reset();
|
||||
}
|
||||
|
||||
std::string backend;
|
||||
std::unique_ptr<CacheTransBufferManager> mTransBufferManager;
|
||||
std::unique_ptr<KVCacheManager> mCacheManager;
|
||||
std::unique_ptr<CacheState> mCacheState;
|
||||
};
|
||||
|
||||
TEST_F(AgentCommTest, AgentConnectionManagerBasic)
|
||||
TEST_P(AgentCommTest, AgentConnectionManagerBasic)
|
||||
{
|
||||
std::vector<CacheTransBufferManager*> bufferManagers{mTransBufferManager.get()};
|
||||
auto connectionManager = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState);
|
||||
auto connectionManager = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState, backend);
|
||||
ASSERT_TRUE(connectionManager != nullptr);
|
||||
ASSERT_EQ(connectionManager->getCacheTransBufferManagers().size(), bufferManagers.size());
|
||||
ASSERT_TRUE(connectionManager->getCacheTransBufferManagers().front() != nullptr);
|
||||
@ -126,11 +168,11 @@ TEST_F(AgentCommTest, AgentConnectionManagerBasic)
|
||||
ASSERT_EQ(commState.getAgentState().size(), 1);
|
||||
}
|
||||
|
||||
TEST_F(AgentCommTest, AgentConnectionManagerConnect)
|
||||
TEST_P(AgentCommTest, AgentConnectionManagerConnect)
|
||||
{
|
||||
std::vector<CacheTransBufferManager*> bufferManagers{mTransBufferManager.get()};
|
||||
auto connectionManager0 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState);
|
||||
auto connectionManager1 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState);
|
||||
auto connectionManager0 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState, backend);
|
||||
auto connectionManager1 = std::make_unique<AgentConnectionManager>(bufferManagers, *mCacheState, backend);
|
||||
auto agentName0 = connectionManager0->getAgentName();
|
||||
auto agentName1 = connectionManager1->getAgentName();
|
||||
ASSERT_TRUE(!agentName0.empty());
|
||||
@ -189,3 +231,6 @@ TEST_F(AgentCommTest, AgentConnectionManagerConnect)
|
||||
}
|
||||
TLLM_LOG_INFO("after finish");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(AvailableBackends, AgentCommTest, ::testing::ValuesIn(getAvailableBackends()),
|
||||
[](::testing::TestParamInfo<AgentCommTest::ParamType> const& info) { return info.param; });
|
||||
|
||||
@ -22,11 +22,27 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <vector>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
using namespace tensorrt_llm::executor::kv_cache;
|
||||
|
||||
std::vector<std::string> getAvailableBackends()
|
||||
{
|
||||
std::vector<std::string> backends;
|
||||
|
||||
#ifdef TEST_NIXL_BACKEND
|
||||
backends.push_back("nixl");
|
||||
#endif
|
||||
|
||||
#ifdef TEST_MOONCAKE_BACKEND
|
||||
backends.push_back("mooncake");
|
||||
#endif
|
||||
|
||||
return backends;
|
||||
}
|
||||
|
||||
class RegisteredHostMemory
|
||||
{
|
||||
public:
|
||||
@ -54,100 +70,105 @@ private:
|
||||
BaseTransferAgent* mAgentPtr{};
|
||||
};
|
||||
|
||||
class TransferAgentTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
|
||||
class TransferAgentTest : public ::testing::TestWithParam<std::string> // NOLINT(cppcoreguidelines-pro-type-member-init)
|
||||
{
|
||||
public:
|
||||
void SetUp() override {}
|
||||
void SetUp() override
|
||||
{
|
||||
backend = GetParam();
|
||||
}
|
||||
|
||||
void TearDown() override {}
|
||||
|
||||
[[nodiscard]] std::unique_ptr<BaseTransferAgent> makeTransferAgent(BaseAgentConfig const& config)
|
||||
{
|
||||
return tensorrt_llm::executor::kv_cache::makeTransferAgent("nixl", &config);
|
||||
return tensorrt_llm::executor::kv_cache::makeTransferAgent(backend, &config);
|
||||
}
|
||||
|
||||
std::string backend;
|
||||
};
|
||||
|
||||
TEST_F(TransferAgentTest, Basic)
|
||||
TEST_P(TransferAgentTest, Basic)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
auto nixlAgent0 = makeTransferAgent(config0);
|
||||
auto nixlAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
TLLM_CHECK(nixlAgent0);
|
||||
TLLM_CHECK(nixlAgent1);
|
||||
TLLM_CHECK(xferAgent0);
|
||||
TLLM_CHECK(xferAgent1);
|
||||
|
||||
std::vector<char> memory0(100, 10);
|
||||
std::vector<char> memory1(100, 1);
|
||||
|
||||
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent0.get());
|
||||
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
|
||||
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
|
||||
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
|
||||
|
||||
// nixlAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
|
||||
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
// xferAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
|
||||
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
bool checked = false;
|
||||
do
|
||||
{
|
||||
checked = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
// wait for regMem is unpacked by nixlAgent0
|
||||
checked = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
// wait for regMem is unpacked by xferAgent0
|
||||
} while (!checked);
|
||||
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
|
||||
auto status = nixlAgent0->submitTransferRequests(writeReq);
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
status->wait();
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
|
||||
nixlAgent0->invalidateRemoteAgent(agent1);
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
}
|
||||
|
||||
TEST_F(TransferAgentTest, Basic2)
|
||||
TEST_P(TransferAgentTest, Basic2)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
auto nixlAgent0 = makeTransferAgent(config0);
|
||||
auto nixlAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
TLLM_CHECK(nixlAgent0);
|
||||
TLLM_CHECK(nixlAgent1);
|
||||
TLLM_CHECK(xferAgent0);
|
||||
TLLM_CHECK(xferAgent1);
|
||||
|
||||
std::vector<char> memory0(100, 10);
|
||||
std::vector<char> memory1(100, 1);
|
||||
|
||||
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent0.get());
|
||||
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
|
||||
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
|
||||
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
|
||||
|
||||
// nixlAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
|
||||
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
// xferAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
|
||||
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
bool checked = false;
|
||||
do
|
||||
{
|
||||
checked = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
checked = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
} while (!checked);
|
||||
|
||||
TransferRequest readReq{TransferOp::kREAD, regMem0.getDescs(), regMem1.getDescs(), agent1};
|
||||
auto status = nixlAgent0->submitTransferRequests(readReq);
|
||||
auto status = xferAgent0->submitTransferRequests(readReq);
|
||||
status->wait();
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
|
||||
nixlAgent0->invalidateRemoteAgent(agent1);
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
}
|
||||
|
||||
TEST_F(TransferAgentTest, DeviceMemory)
|
||||
TEST_P(TransferAgentTest, DeviceMemory)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
auto nixlAgent0 = makeTransferAgent(config0);
|
||||
auto nixlAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
TLLM_CHECK(nixlAgent0);
|
||||
TLLM_CHECK(nixlAgent1);
|
||||
TLLM_CHECK(xferAgent0);
|
||||
TLLM_CHECK(xferAgent1);
|
||||
char* dev_ptr0;
|
||||
char* dev_ptr1;
|
||||
size_t size = 100;
|
||||
@ -159,20 +180,20 @@ TEST_F(TransferAgentTest, DeviceMemory)
|
||||
cudaMemcpy(dev_ptr0, memory0.data(), size, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dev_ptr1, memory1.data(), size, cudaMemcpyHostToDevice);
|
||||
RegisteredHostMemory regMem0(
|
||||
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr0, size, deviceId}}}, nixlAgent0.get());
|
||||
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr0, size, deviceId}}}, xferAgent0.get());
|
||||
RegisteredHostMemory regMem1(
|
||||
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr1, size, deviceId}}}, nixlAgent1.get());
|
||||
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr1, size, deviceId}}}, xferAgent1.get());
|
||||
|
||||
// nixlAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
|
||||
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
// xferAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
|
||||
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
bool checked = false;
|
||||
do
|
||||
{
|
||||
checked = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
checked = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
} while (!checked);
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
|
||||
auto status = nixlAgent0->submitTransferRequests(writeReq);
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
status->wait();
|
||||
|
||||
cudaMemcpy(memory0.data(), dev_ptr0, size, cudaMemcpyDeviceToHost);
|
||||
@ -181,98 +202,99 @@ TEST_F(TransferAgentTest, DeviceMemory)
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
TLLM_CUDA_CHECK(cudaFree(dev_ptr0));
|
||||
TLLM_CUDA_CHECK(cudaFree(dev_ptr1));
|
||||
nixlAgent0->invalidateRemoteAgent(agent1);
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
}
|
||||
|
||||
TEST_F(TransferAgentTest, Connect)
|
||||
TEST_P(TransferAgentTest, Connect)
|
||||
{
|
||||
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"}, agent2{"agent2"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true}, config2{agent2, true};
|
||||
auto nixlAgent0 = makeTransferAgent(config0);
|
||||
auto nixlAgent1 = makeTransferAgent(config1);
|
||||
auto nixlAgent2 = makeTransferAgent(config2);
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent2 = makeTransferAgent(config2);
|
||||
|
||||
TLLM_CHECK(nixlAgent0);
|
||||
TLLM_CHECK(nixlAgent1);
|
||||
TLLM_CHECK(xferAgent0);
|
||||
TLLM_CHECK(xferAgent1);
|
||||
|
||||
std::vector<char> memory0(100, 10);
|
||||
std::vector<char> memory1(100, 1);
|
||||
MemoryDescs memDescs0{MemoryType::kDRAM, {MemoryDesc{memory0}}};
|
||||
MemoryDescs memDescs1{MemoryType::kDRAM, {MemoryDesc{memory1}}};
|
||||
|
||||
nixlAgent0->registerMemory(memDescs0);
|
||||
nixlAgent1->registerMemory(memDescs1);
|
||||
nixlAgent2->registerMemory(memDescs0);
|
||||
xferAgent0->registerMemory(memDescs0);
|
||||
xferAgent1->registerMemory(memDescs1);
|
||||
xferAgent2->registerMemory(memDescs0);
|
||||
|
||||
// nixlAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
|
||||
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
// xferAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
|
||||
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
bool checked = false;
|
||||
do
|
||||
{
|
||||
checked = nixlAgent0->checkRemoteDescs(agent1, memDescs1);
|
||||
checked = xferAgent0->checkRemoteDescs(agent1, memDescs1);
|
||||
} while (!checked);
|
||||
TransferRequest writeReq{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
|
||||
auto status = nixlAgent0->submitTransferRequests(writeReq);
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
status->wait();
|
||||
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
nixlAgent2->loadRemoteAgent(agent1, connectionInfo);
|
||||
xferAgent2->loadRemoteAgent(agent1, connectionInfo);
|
||||
checked = false;
|
||||
do
|
||||
{
|
||||
checked = nixlAgent2->checkRemoteDescs(agent1, memDescs1);
|
||||
checked = xferAgent2->checkRemoteDescs(agent1, memDescs1);
|
||||
} while (!checked);
|
||||
TransferRequest writeReq2{TransferOp::kWRITE, memDescs0, memDescs1, agent1};
|
||||
auto status2 = nixlAgent2->submitTransferRequests(writeReq2);
|
||||
auto status2 = xferAgent2->submitTransferRequests(writeReq2);
|
||||
status2->wait();
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
nixlAgent0->invalidateRemoteAgent(agent1);
|
||||
nixlAgent2->invalidateRemoteAgent(agent1);
|
||||
nixlAgent0->deregisterMemory(memDescs0);
|
||||
nixlAgent1->deregisterMemory(memDescs1);
|
||||
nixlAgent2->deregisterMemory(memDescs0);
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
xferAgent2->invalidateRemoteAgent(agent1);
|
||||
xferAgent0->deregisterMemory(memDescs0);
|
||||
xferAgent1->deregisterMemory(memDescs1);
|
||||
xferAgent2->deregisterMemory(memDescs0);
|
||||
}
|
||||
|
||||
TEST_F(TransferAgentTest, SyncMessage)
|
||||
TEST_P(TransferAgentTest, SyncMessage)
|
||||
{
|
||||
constexpr std::size_t MAX_QUERY_TIMES = std::numeric_limits<size_t>::max();
|
||||
std::string const agent0{"agent0"}, agent1{"agent1"};
|
||||
BaseAgentConfig config0{agent0, true}, config1{agent1, true};
|
||||
auto nixlAgent0 = makeTransferAgent(config0);
|
||||
auto nixlAgent1 = makeTransferAgent(config1);
|
||||
auto xferAgent0 = makeTransferAgent(config0);
|
||||
auto xferAgent1 = makeTransferAgent(config1);
|
||||
|
||||
TLLM_CHECK(nixlAgent0);
|
||||
TLLM_CHECK(nixlAgent1);
|
||||
TLLM_CHECK(xferAgent0);
|
||||
TLLM_CHECK(xferAgent1);
|
||||
|
||||
std::vector<char> memory0(100, 10);
|
||||
std::vector<char> memory1(100, 1);
|
||||
|
||||
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent0.get());
|
||||
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent0.get());
|
||||
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
|
||||
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent0.get());
|
||||
|
||||
RegisteredHostMemory regMem2(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, nixlAgent1.get());
|
||||
RegisteredHostMemory regMem3(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
|
||||
RegisteredHostMemory regMem2(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent1.get());
|
||||
RegisteredHostMemory regMem3(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
|
||||
|
||||
// nixlAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
|
||||
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
// xferAgent0->loadRemoteAgent(agent1);
|
||||
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
|
||||
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
|
||||
bool checked = false;
|
||||
do
|
||||
{
|
||||
checked = nixlAgent0->checkRemoteDescs(agent1, regMem3.getDescs());
|
||||
checked = xferAgent0->checkRemoteDescs(agent1, regMem3.getDescs());
|
||||
} while (!checked);
|
||||
auto syncMessage = std::string("agent_sync_message");
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1};
|
||||
auto status = nixlAgent0->submitTransferRequests(writeReq);
|
||||
nixlAgent0->notifySyncMessage(agent1, syncMessage);
|
||||
auto status = xferAgent0->submitTransferRequests(writeReq);
|
||||
xferAgent0->notifySyncMessage(agent1, syncMessage);
|
||||
|
||||
auto notif = nixlAgent1->getNotifiedSyncMessages();
|
||||
auto notif = xferAgent1->getNotifiedSyncMessages();
|
||||
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif.size() == 0; i++)
|
||||
{
|
||||
notif = nixlAgent1->getNotifiedSyncMessages();
|
||||
notif = xferAgent1->getNotifiedSyncMessages();
|
||||
}
|
||||
status->wait();
|
||||
TLLM_CHECK(status->isCompleted());
|
||||
TLLM_CHECK(notif.size() == 1);
|
||||
TLLM_CHECK(notif[agent0].size() == 1);
|
||||
@ -281,25 +303,25 @@ TEST_F(TransferAgentTest, SyncMessage)
|
||||
TLLM_CHECK(memory0 == memory1);
|
||||
|
||||
std::string syncMessage2 = "two_agent_sync_message";
|
||||
nixlAgent0->notifySyncMessage(agent1, syncMessage2);
|
||||
auto notif2 = nixlAgent1->getNotifiedSyncMessages();
|
||||
xferAgent0->notifySyncMessage(agent1, syncMessage2);
|
||||
auto notif2 = xferAgent1->getNotifiedSyncMessages();
|
||||
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif2.size() == 0; i++)
|
||||
{
|
||||
notif2 = nixlAgent1->getNotifiedSyncMessages();
|
||||
notif2 = xferAgent1->getNotifiedSyncMessages();
|
||||
}
|
||||
TLLM_CHECK(notif2.size() == 1);
|
||||
TLLM_CHECK(notif2[agent0].size() == 1);
|
||||
TLLM_CHECK(notif2[agent0][0] == syncMessage2);
|
||||
|
||||
// nixlAgent1->loadRemoteAgent(agent0);
|
||||
auto connectionInfo2 = nixlAgent0->getLocalConnectionInfo();
|
||||
nixlAgent1->loadRemoteAgent(agent0, connectionInfo2);
|
||||
// xferAgent1->loadRemoteAgent(agent0);
|
||||
auto connectionInfo2 = xferAgent0->getLocalConnectionInfo();
|
||||
xferAgent1->loadRemoteAgent(agent0, connectionInfo2);
|
||||
std::string syncMessage3 = "three_agent_sync_message";
|
||||
nixlAgent1->notifySyncMessage(agent0, syncMessage3);
|
||||
auto notif3 = nixlAgent0->getNotifiedSyncMessages();
|
||||
xferAgent1->notifySyncMessage(agent0, syncMessage3);
|
||||
auto notif3 = xferAgent0->getNotifiedSyncMessages();
|
||||
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif3.size() == 0; i++)
|
||||
{
|
||||
notif3 = nixlAgent0->getNotifiedSyncMessages();
|
||||
notif3 = xferAgent0->getNotifiedSyncMessages();
|
||||
}
|
||||
TLLM_CHECK(notif3.size() == 1);
|
||||
TLLM_CHECK(notif3[agent1].size() == 1);
|
||||
@ -308,19 +330,20 @@ TEST_F(TransferAgentTest, SyncMessage)
|
||||
bool checked2 = false;
|
||||
do
|
||||
{
|
||||
checked2 = nixlAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
checked2 = xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs());
|
||||
} while (!checked2);
|
||||
|
||||
std::string syncMessage4 = "four_agent_sync_message";
|
||||
TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0};
|
||||
auto status1 = nixlAgent1->submitTransferRequests(writeReq1);
|
||||
nixlAgent1->notifySyncMessage(agent0, syncMessage4);
|
||||
auto status1 = xferAgent1->submitTransferRequests(writeReq1);
|
||||
xferAgent1->notifySyncMessage(agent0, syncMessage4);
|
||||
|
||||
auto notif4 = nixlAgent0->getNotifiedSyncMessages();
|
||||
auto notif4 = xferAgent0->getNotifiedSyncMessages();
|
||||
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif4.size() == 0; i++)
|
||||
{
|
||||
notif4 = nixlAgent0->getNotifiedSyncMessages();
|
||||
notif4 = xferAgent0->getNotifiedSyncMessages();
|
||||
}
|
||||
status1->wait();
|
||||
TLLM_CHECK(status1->isCompleted());
|
||||
TLLM_CHECK(notif4.size() == 1);
|
||||
TLLM_CHECK(notif4[agent1].size() == 1);
|
||||
@ -335,11 +358,11 @@ TEST_F(TransferAgentTest, SyncMessage)
|
||||
std::stringstream ss;
|
||||
Serialization::serialize(state, ss);
|
||||
std::string serializedState = ss.str();
|
||||
nixlAgent0->notifySyncMessage(agent1, serializedState);
|
||||
auto notif5 = nixlAgent1->getNotifiedSyncMessages();
|
||||
xferAgent0->notifySyncMessage(agent1, serializedState);
|
||||
auto notif5 = xferAgent1->getNotifiedSyncMessages();
|
||||
for (size_t i = 0; i < MAX_QUERY_TIMES && notif5.size() == 0; i++)
|
||||
{
|
||||
notif5 = nixlAgent1->getNotifiedSyncMessages();
|
||||
notif5 = xferAgent1->getNotifiedSyncMessages();
|
||||
}
|
||||
TLLM_CHECK(notif5.size() == 1);
|
||||
TLLM_CHECK(notif5[agent0].size() == 1);
|
||||
@ -348,10 +371,16 @@ TEST_F(TransferAgentTest, SyncMessage)
|
||||
auto state2 = Serialization::deserializeCommState(ss2);
|
||||
TLLM_CHECK(state2 == state);
|
||||
|
||||
nixlAgent0->invalidateRemoteAgent(agent1);
|
||||
nixlAgent1->invalidateRemoteAgent(agent0);
|
||||
xferAgent0->invalidateRemoteAgent(agent1);
|
||||
xferAgent1->invalidateRemoteAgent(agent0);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(AvailableBackends, TransferAgentTest, ::testing::ValuesIn(getAvailableBackends()),
|
||||
[](::testing::TestParamInfo<TransferAgentTest::ParamType> const& info) { return info.param; });
|
||||
|
||||
// Skip LoopbackAgentTest for mooncake backend for now
|
||||
#ifdef TEST_NIXL_BACKEND
|
||||
|
||||
class LoopbackAgentTest : public ::testing::Test,
|
||||
public ::testing::WithParamInterface<bool> // NOLINT(cppcoreguidelines-pro-type-member-init)
|
||||
{
|
||||
@ -466,3 +495,5 @@ TEST_P(LoopbackAgentTest, GpuToFile)
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(, LoopbackAgentTest, ::testing::Values(true, false));
|
||||
|
||||
#endif // TEST_NIXL_BACKEND
|
||||
|
||||
@ -46,6 +46,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <tensorrt_llm/batch_manager/cacheTransBuffer.h>
|
||||
@ -713,7 +714,7 @@ protected:
|
||||
return;
|
||||
}
|
||||
else if (tensorrt_llm::common::getEnvUseMPIKvCache() || tensorrt_llm::common::getEnvUseUCXKvCache()
|
||||
|| tensorrt_llm::common::getEnvUseNixlKvCache())
|
||||
|| tensorrt_llm::common::getEnvUseNixlKvCache() || tensorrt_llm::common::getEnvUseMooncakeKvCache())
|
||||
{
|
||||
int maxNumTokens = 2048;
|
||||
mCacheTransBufferManagers.clear();
|
||||
@ -729,7 +730,15 @@ protected:
|
||||
}
|
||||
bool isUcx = tensorrt_llm::common::getEnvUseUCXKvCache();
|
||||
bool isNixl = tensorrt_llm::common::getEnvUseNixlKvCache();
|
||||
TLLM_LOG_INFO("Enable %s KV cache transport.", isUcx ? "UCX" : isNixl ? "NIXL" : "MPI");
|
||||
bool isMooncake = tensorrt_llm::common::getEnvUseMooncakeKvCache();
|
||||
// Skip tests for MOONCAKE when on Rocky8
|
||||
bool isRocky8 = std::filesystem::exists("/etc/redhat-release");
|
||||
isMooncake = isMooncake && !isRocky8;
|
||||
TLLM_LOG_INFO("Enable %s KV cache transport.",
|
||||
isUcx ? "UCX"
|
||||
: isNixl ? "NIXL"
|
||||
: isMooncake ? "MOONCAKE"
|
||||
: "MPI");
|
||||
|
||||
if (isUcx)
|
||||
{
|
||||
@ -756,7 +765,12 @@ protected:
|
||||
setenv("TRTLLM_NIXL_PORT", std::to_string(port).c_str(), 1);
|
||||
|
||||
mConnectionManager
|
||||
= std::make_unique<texec::kv_cache::AgentConnectionManager>(bufferManagers, *mCacheState);
|
||||
= std::make_unique<texec::kv_cache::AgentConnectionManager>(bufferManagers, *mCacheState, "nixl");
|
||||
}
|
||||
else if (isMooncake)
|
||||
{
|
||||
mConnectionManager = std::make_unique<texec::kv_cache::AgentConnectionManager>(
|
||||
bufferManagers, *mCacheState, "mooncake");
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -783,7 +797,7 @@ protected:
|
||||
std::vector<int> contextRankVec(mContextRankSize);
|
||||
std::iota(contextRankVec.begin(), contextRankVec.end(), 0);
|
||||
|
||||
if (isUcx || isNixl)
|
||||
if (isUcx || isNixl || isMooncake)
|
||||
{
|
||||
auto commState = mConnectionManager->getCommState();
|
||||
namespace su = tensorrt_llm::executor::serialize_utils;
|
||||
@ -1286,9 +1300,9 @@ TEST_P(AsymmetricalCacheTest, TestCase)
|
||||
int indexerDimPerHead = std::get<17>(param);
|
||||
int indexerKCacheQuantBlockSize = std::get<18>(param);
|
||||
|
||||
if (genCp > 1 && tensorrt_llm::common::getEnvUseNixlKvCache())
|
||||
if (genCp > 1 && (tensorrt_llm::common::getEnvUseNixlKvCache() || tensorrt_llm::common::getEnvUseMooncakeKvCache()))
|
||||
{
|
||||
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL backend for CP.";
|
||||
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
|
||||
}
|
||||
std::vector<int> lenList = {30, 10, 60, 80};
|
||||
if (genCp > 1)
|
||||
@ -1410,9 +1424,9 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
|
||||
int indexerDimPerHead = std::get<17>(param);
|
||||
int indexerKCacheQuantBlockSize = std::get<18>(param);
|
||||
|
||||
if (genCp > 1 && tensorrt_llm::common::getEnvUseNixlKvCache())
|
||||
if (genCp > 1 && (tensorrt_llm::common::getEnvUseNixlKvCache() || tensorrt_llm::common::getEnvUseMooncakeKvCache()))
|
||||
{
|
||||
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL backend for CP.";
|
||||
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
|
||||
}
|
||||
setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
|
||||
|
||||
|
||||
@ -60,12 +60,12 @@ def BUILD_CONFIGS = [
|
||||
// Vanilla TARNAME is used for packaging in runLLMPackage
|
||||
// cmake-vars cannot be empty, so passing (default) multi-device configuration.
|
||||
(CONFIG_LINUX_X86_64_VANILLA) : [
|
||||
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
|
||||
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake --micro_benchmarks",
|
||||
(TARNAME) : "TensorRT-LLM.tar.gz",
|
||||
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;103-real;120-real",
|
||||
],
|
||||
(CONFIG_LINUX_X86_64_PYBIND) : [
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake --micro_benchmarks",
|
||||
(TARNAME) : "pybind-TensorRT-LLM.tar.gz",
|
||||
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;103-real;120-real",
|
||||
],
|
||||
@ -80,13 +80,13 @@ def BUILD_CONFIGS = [
|
||||
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;103-real;120-real",
|
||||
],
|
||||
(CONFIG_LINUX_AARCH64): [
|
||||
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl",
|
||||
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake",
|
||||
(TARNAME) : "TensorRT-LLM-GH200.tar.gz",
|
||||
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
|
||||
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
|
||||
],
|
||||
(CONFIG_LINUX_AARCH64_PYBIND): [
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl",
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake",
|
||||
(TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
|
||||
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
|
||||
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
|
||||
|
||||
@ -458,6 +458,7 @@ def main(*,
|
||||
trt_root: str = '/usr/local/tensorrt',
|
||||
nccl_root: str = None,
|
||||
nixl_root: str = None,
|
||||
mooncake_root: str = None,
|
||||
internal_cutlass_kernels_root: str = None,
|
||||
clean: bool = False,
|
||||
clean_wheel: bool = False,
|
||||
@ -559,6 +560,11 @@ def main(*,
|
||||
if nixl_root is not None:
|
||||
cmake_def_args.append(f"-DNIXL_ROOT={nixl_root}")
|
||||
|
||||
if mooncake_root is not None:
|
||||
if on_windows:
|
||||
raise RuntimeError("Mooncake is not supported on Windows.")
|
||||
cmake_def_args.append(f"-DMOONCAKE_ROOT={mooncake_root}")
|
||||
|
||||
build_dir = get_build_dir(build_dir, build_type)
|
||||
first_build = not Path(build_dir, "CMakeFiles").exists()
|
||||
|
||||
@ -819,6 +825,14 @@ def main(*,
|
||||
build_run(
|
||||
f"find {nixl_dir} -type f -name '*.so*' -exec patchelf --set-rpath \'$ORIGIN:$ORIGIN/plugins:$ORIGIN/../:$ORIGIN/../ucx/:$ORIGIN/../../ucx/\' {{}} \\;"
|
||||
)
|
||||
if os.path.exists(
|
||||
build_dir /
|
||||
"tensorrt_llm/executor/cache_transmission/mooncake_utils/libtensorrt_llm_mooncake_wrapper.so"
|
||||
):
|
||||
install_file(
|
||||
build_dir /
|
||||
"tensorrt_llm/executor/cache_transmission/mooncake_utils/libtensorrt_llm_mooncake_wrapper.so",
|
||||
lib_dir / "libtensorrt_llm_mooncake_wrapper.so")
|
||||
install_file(
|
||||
build_dir /
|
||||
"tensorrt_llm/kernels/decoderMaskedMultiheadAttention/libdecoder_attention_0.so",
|
||||
@ -1041,6 +1055,10 @@ def add_arguments(parser: ArgumentParser):
|
||||
help="Directory containing NCCL headers and libraries")
|
||||
parser.add_argument("--nixl_root",
|
||||
help="Directory containing NIXL headers and libraries")
|
||||
parser.add_argument(
|
||||
"--mooncake_root",
|
||||
help=
|
||||
"Directory containing Mooncake transfer engine headers and libraries")
|
||||
parser.add_argument(
|
||||
"--internal-cutlass-kernels-root",
|
||||
default="",
|
||||
|
||||
6
setup.py
6
setup.py
@ -114,9 +114,9 @@ else:
|
||||
'libs/libnvinfer_plugin_tensorrt_llm.so',
|
||||
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
|
||||
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
|
||||
'libs/ucx/**/*', 'libs/libpg_utils.so',
|
||||
'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
|
||||
'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
|
||||
'libs/libtensorrt_llm_mooncake_wrapper.so', 'libs/ucx/**/*',
|
||||
'libs/libpg_utils.so', 'libs/libdecoder_attention_1.so',
|
||||
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
|
||||
'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
|
||||
'deep_ep/LICENSE', 'deep_ep/*.py', 'deep_ep_cpp_tllm.*.so',
|
||||
"include/**/*", 'deep_gemm/LICENSE', 'deep_gemm/include/**/*',
|
||||
|
||||
@ -42,6 +42,7 @@ def create_kv_cache_transceiver(
|
||||
cache_transceiver_config.backend = "NIXL"
|
||||
# Ordered by priority
|
||||
env_vars = [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
|
||||
("TRTLLM_USE_MOONCAKE_KVCACHE", "MOONCAKE"),
|
||||
("TRTLLM_USE_MPI_KVCACHE", "MPI")]
|
||||
for env_var, be_type in env_vars:
|
||||
if getenv(env_var) == "1":
|
||||
|
||||
@ -1739,10 +1739,11 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
|
||||
Configuration for the cache transceiver.
|
||||
"""
|
||||
|
||||
backend: Optional[Literal["DEFAULT", "UCX", "NIXL", "MPI"]] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"The communication backend type to use for the cache transceiver.")
|
||||
backend: Optional[Literal[
|
||||
"DEFAULT", "UCX", "NIXL", "MOONCAKE", "MPI"]] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"The communication backend type to use for the cache transceiver.")
|
||||
|
||||
max_tokens_in_buffer: Optional[int] = Field(
|
||||
default=None,
|
||||
|
||||
@ -25,6 +25,7 @@ class KVCacheType(Enum):
|
||||
MPI = auto()
|
||||
UCX = auto()
|
||||
NIXL = auto()
|
||||
MOONCAKE = auto()
|
||||
|
||||
|
||||
def get_multi_gpu_env(kv_cache_type=KVCacheType.NONE, llama_multi_gpu=False):
|
||||
@ -37,6 +38,9 @@ def get_multi_gpu_env(kv_cache_type=KVCacheType.NONE, llama_multi_gpu=False):
|
||||
env["TRTLLM_USE_UCX_KVCACHE"] = "1"
|
||||
case KVCacheType.NIXL:
|
||||
env["TRTLLM_USE_NIXL_KVCACHE"] = "1"
|
||||
case KVCacheType.MOONCAKE:
|
||||
env["TRTLLM_USE_MOONCAKE_KVCACHE"] = "1"
|
||||
env["MC_FORCE_TCP"] = "1"
|
||||
case KVCacheType.NONE:
|
||||
pass
|
||||
case _:
|
||||
@ -502,8 +506,9 @@ def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):
|
||||
|
||||
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
|
||||
indirect=True)
|
||||
@pytest.mark.parametrize("kvcache_type", [KVCacheType.NIXL, KVCacheType.UCX],
|
||||
ids=["nixl_kvcache", "ucx_kvcache"])
|
||||
@pytest.mark.parametrize(
|
||||
"kvcache_type", [KVCacheType.NIXL, KVCacheType.UCX, KVCacheType.MOONCAKE],
|
||||
ids=["nixl_kvcache", "ucx_kvcache", "mooncake_kvcache"])
|
||||
@pytest.mark.parametrize("nprocs", [2, 8], ids=["2proc", "8proc"])
|
||||
def test_cache_transceiver(build_google_tests, nprocs, kvcache_type, build_dir):
|
||||
|
||||
|
||||
@ -231,6 +231,7 @@ l0_dgx_h100:
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90] ISOLATION
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-nixl_kvcache-90] ISOLATION
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-ucx_kvcache-90] ISOLATION
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] ISOLATION
|
||||
- cpp/test_multi_gpu.py::test_user_buffer[2proc-90]
|
||||
- cpp/test_multi_gpu.py::test_enc_dec[t5-90]
|
||||
- cpp/test_multi_gpu.py::test_llama_executor[llama-orchestrator-90]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user