/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#define UCX_WRAPPER_LIB_NAME "tensorrt_llm_ucx_wrapper"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary(name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/executor/cache_transmission/mpi_utils/connection.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <csignal>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <random>
#include <tensorrt_llm/batch_manager/cacheTransBuffer.h>
#include <tensorrt_llm/batch_manager/mlaCacheFormatter.h>
#include <tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h>
#include "gtest/gtest.h"
#include <gmock/gmock.h>
namespace tr = tensorrt_llm::runtime;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using LlmRequest = tensorrt_llm::batch_manager::LlmRequest;
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
using namespace tensorrt_llm::batch_manager;
namespace texec = tensorrt_llm::executor;
using testing::Return;
using testing::ReturnRef;
// ---------------------------------------
// RequestInfoTest
// ---------------------------------------
namespace
{
std::mutex mDllMutex;
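// Round-trips a value through its static serialize/deserialize helpers and checks that the
// reported serializedSize matches the number of bytes actually written to the stream.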
template <typename T>
T serializeDeserialize(T const& val)
{
auto size = T::serializedSize(val);
std::ostringstream oss;
T::serialize(val, oss);
EXPECT_EQ(oss.str().size(), size);
std::istringstream iss(oss.str());
return T::deserialize(iss);
}
} // namespace
class RequestInfoTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
public:
void SetUp() override {}
void TearDown() override {}
};
TEST_F(RequestInfoTest, Basic)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT});
RequestInfo info{1, *state};
auto info2 = serializeDeserialize(info);
EXPECT_EQ(info, info2);
}
// ---------------------------------------
// CacheConfigTest
// ---------------------------------------
class CacheConfigTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
public:
void SetUp() override {}
void TearDown() override {}
};
TEST_F(CacheConfigTest, EqualTo)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
using tensorrt_llm::executor::kv_cache::CacheState;
constexpr SizeType32 vocabSize{25};
constexpr SizeType32 nbAttentionLayers{10};
constexpr SizeType32 nbRnnLayers{2};
constexpr SizeType32 nbHeads{12};
constexpr SizeType32 hiddenSize{768};
constexpr nvinfer1::DataType dtype{nvinfer1::DataType::kFLOAT};
constexpr SizeType32 tokensPerBlock{64};
constexpr SizeType32 tensorParallelism{8};
constexpr SizeType32 pipelineParallelism{2};
constexpr SizeType32 contextParallelism{1};
constexpr SizeType32 sizePerHead{hiddenSize / nbHeads};
constexpr CacheState::AttentionType attentionType{CacheState::AttentionType::kDEFAULT};
constexpr int kvFactor = 2;
tr::ModelConfig modelConfig{
vocabSize, nbAttentionLayers + nbRnnLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
modelConfig.setTokensPerBlock(tokensPerBlock);
tr::WorldConfig worldConfig{tensorParallelism, pipelineParallelism, contextParallelism};
texec::kv_cache::CacheState::ModelConfig cacheStateCfg{
modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()};
texec::kv_cache::CacheState state0{
cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, kvFactor};
texec::kv_cache::CacheState state1{nbAttentionLayers, nbHeads, sizePerHead, tokensPerBlock, tensorParallelism,
pipelineParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism};
EXPECT_EQ(state0, state1);
}
// ---------------------------------------
// MockTransceiverTest
// ---------------------------------------
class MockDataSender : public DataSender
{
public:
MockDataSender()
{
ON_CALL(*this, getCommState).WillByDefault(ReturnRef(mState));
ON_CALL(*this, recvRequestInfo)
.WillByDefault(Return(RequestInfo{0,
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1));
}
MOCK_METHOD(RequestInfo, recvRequestInfo, (), (override));
MOCK_METHOD(void, sendSync, (LlmRequest const&), (override));
MOCK_METHOD(texec::kv_cache::CommState const&, getCommState, (), (const, override));
MOCK_METHOD(void, setCommState, (texec::kv_cache::CommState), (override));
MOCK_METHOD(size_t, getCounterpartsCount, (LlmRequest::RequestIdType), (const, override));
MOCK_METHOD(void, release, (LlmRequest::RequestIdType), (override));
private:
static texec::kv_cache::CommState mState;
};
texec::kv_cache::CommState MockDataSender::mState;
class MockDataReceiver : public DataReceiver
{
public:
MOCK_METHOD(TransferSession, sendRequestInfo, (LlmRequest const&), (override));
MOCK_METHOD(void, receiveSync, (TransferSession&), (override));
};
class MockTransceiverTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
public:
void SetUp() override {}
void TearDown() override {}
static auto makeLlmRequest(
LlmRequest::RequestIdType requestId = 0, SizeType32 maxNewTokens = 1, VecTokens inputTokens = {-1})
{
texec::Request request{std::move(inputTokens), maxNewTokens};
auto state = std::make_unique<texec::DataTransceiverState>();
auto stats = texec::ContextPhaseParams({}, requestId, state.release(), std::nullopt);
request.setContextPhaseParams(std::move(stats));
return std::make_unique<LlmRequest>(requestId, std::move(request));
}
};
TEST_F(MockTransceiverTest, MpiResponderBasic)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto sender = std::make_unique<MockDataSender>();
EXPECT_CALL(*sender, recvRequestInfo)
.WillOnce(Return(RequestInfo{0,
texec::DataTransceiverState{texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
EXPECT_CALL(*sender, sendSync).WillOnce(Return());
EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1));
EXPECT_CALL(*sender, release).WillOnce(Return());
DataResponder responder{std::move(sender)};
auto request = makeLlmRequest(0);
auto future = responder.respondAndSendAsync(*request);
future.get();
}
TEST_F(MockTransceiverTest, MpiRequesterBasic)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto receiver = std::make_unique<MockDataReceiver>();
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{std::vector<int>{0}});
EXPECT_CALL(*receiver, sendRequestInfo)
.WillOnce(Return(TransferSession({nullptr}, DataContext{0}, *state, *state,
tensorrt_llm::runtime::BufferManager{std::make_shared<tr::CudaStream>()}, nullptr)));
EXPECT_CALL(*receiver, receiveSync).WillOnce(Return());
DataRequester requester{std::move(receiver)};
auto request = makeLlmRequest(0);
auto stats = texec::ContextPhaseParams({}, 0, state.release(), std::nullopt);
request->setContextPhaseParams(std::move(stats));
auto future = requester.requestAndReceiveAsync(*request);
future.get();
}
// TODO: Restore multi-rank tests.
// ---------------------------------------
// RealTransceiverTest
// ---------------------------------------
class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
protected:
void SetUp() override {}
void TearDown() override
{
for (auto& future : mFutures)
{
if (future.valid())
{
future.get();
}
}
}
SizeType32 setUpCommunicator()
{
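// Pair the ranks: even world ranks act as context-side senders and odd ranks as generation-side
// receivers; the split below gives each role its own session communicator.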
tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
mComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
mWorldSize = mComm->getSize();
mLocalRank = mComm->getRank() / 2;
isSender = mComm->getRank() % 2 == 0;
tensorrt_llm::mpi::MpiComm::setSession(mComm->split(static_cast<int>(isSender), mLocalRank));
return mWorldSize;
}
void setUpCacheManager()
{
auto constexpr numLayers = 4;
auto constexpr numHeads = 2;
auto constexpr sizePerHead = 64;
auto constexpr hiddenSize = numHeads * sizePerHead;
auto constexpr tokensPerBlock = 8;
auto constexpr maxBlocksPerSeq = 10;
auto constexpr maxBeamWidth = 4;
auto constexpr sinkTokenLength = 0;
mMaxNumSequences = 8;
auto const stream = std::make_shared<tr::CudaStream>();
auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq;
auto constexpr maxAttentionWindow = maxNumTokens;
auto constexpr inputLength = maxNumTokens - tokensPerBlock - 1;
auto constexpr numSharedBlocks = inputLength / tokensPerBlock;
auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth;
auto totalNumBlocks = mMaxNumSequences * numBlocksPerSeq;
auto constexpr blocksInSecondaryPool = 0;
auto constexpr enableBlockReuse = true;
auto constexpr onboardBlocks = true;
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}};
mManager = std::make_unique<KVCacheManager>(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow,
mMaxNumSequences, maxBeamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt,
dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, CacheType::kSELF,
std::nullopt, nullptr, true);
mCacheState = std::make_unique<texec::kv_cache::CacheState>(
numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType);
if (tensorrt_llm::common::getEnvUseUCXKvCache())
{
std::lock_guard<std::mutex> lock(mDllMutex);
void* WrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
TLLM_CHECK_WITH_INFO(WrapperLibHandle != nullptr, "Failed to open the UCX wrapper library.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
TLLM_CHECK_WITH_INFO(ret != nullptr,
"Unable to load UCX wrapper library symbol; a possible cause is that the TensorRT LLM library was not "
"built with UCX support. Please rebuild in a UCX-enabled environment.");
return ret;
};
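// POSIX dlsym idiom: write through a void** alias of the function pointer, since directly casting
// an object pointer to a function pointer is not portable C++.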
std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
*(void**) (&makeUcxConnectionManager) = load_sym(WrapperLibHandle, "makeUcxConnectionManager");
mConnectionManager = makeUcxConnectionManager();
auto commState = mConnectionManager->getCommState();
namespace su = tensorrt_llm::executor::serialize_utils;
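// Out-of-band handshake: world rank 0 (the context side) serializes its UCX CommState and ships it
// over MPI (message size first, then the bytes) so the generation rank can connect back.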
if (tensorrt_llm::mpi::MpiComm::world().getRank() == 0)
{
std::ostringstream oStream;
su::serialize(commState, oStream);
auto str = oStream.str();
std::vector<char> buffer(str.begin(), str.end());
int genRank = 1;
int64_t bufferSize = buffer.size();
TLLM_LOG_DEBUG(
tensorrt_llm::mpi::MpiComm::world().getRank(), "send bufferSize: %ld to %d", bufferSize, genRank);
tensorrt_llm::mpi::MpiComm::world().sendRawTag(
&bufferSize, 1, tensorrt_llm::mpi::MpiType::kINT64, genRank, 0x1F);
tensorrt_llm::mpi::MpiComm::world().sendRawTag(
buffer.data(), buffer.size(), tensorrt_llm::mpi::MpiType::kCHAR, genRank, 0x2F);
TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(), "send buffer to %d", genRank);
mContextCommState = std::make_unique<tensorrt_llm::executor::kv_cache::CommState>(commState);
}
else
{
int64_t bufferSize;
tensorrt_llm::mpi::MpiComm::world().recvRawTag(
&bufferSize, 1, tensorrt_llm::mpi::MpiType::kINT64, 0, 0x1F);
TLLM_LOG_DEBUG(
tensorrt_llm::mpi::MpiComm::world().getRank(), "recv bufferSize: %ld from 0", bufferSize);
std::vector<char> recvBuffer(bufferSize);
tensorrt_llm::mpi::MpiComm::world().recvRawTag(
recvBuffer.data(), bufferSize, tensorrt_llm::mpi::MpiType::kCHAR, 0, 0x2F);
TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(), "recv buffer (%ld bytes) from 0", bufferSize);
su::VectorWrapBuf<char> strbuf(recvBuffer);
std::istream is(&strbuf);
mContextCommState = std::make_unique<tensorrt_llm::executor::kv_cache::CommState>(
su::deserialize<tensorrt_llm::executor::kv_cache::CommState>(is));
}
}
else
{
mConnectionManager = std::make_unique<texec::kv_cache::MpiConnectionManager>(mComm);
mContextCommState
= std::make_unique<texec::kv_cache::CommState>(texec::kv_cache::CommState{std::vector<int>{0}});
}
// UVM appears to be incompatible with MPI; this is still under investigation.
bool constexpr useUvm = false;
mManager->allocatePools(useUvm);
}
void setUpCacheTransceiver()
{
int maxNumTokens = 1024;
mCacheTransBufferManager = std::make_unique<CacheTransBufferManager>(mManager.get(), maxNumTokens);
if (isSender)
{
mResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mConnectionManager.get(), *mCacheState, mLocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get())));
}
else
{
mRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mConnectionManager.get(), *mCacheState, mLocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get())));
}
}
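// Builds a request whose ContextPhaseParams carries the context side's DataTransceiverState
// (comm state + cache state); the receiver uses this state to locate and pull the KV cache.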
auto makeLlmRequest(SizeType32 length)
{
constexpr SizeType32 maxNewTokens{1};
// create request with tokens [length, ..., length] (<length> tokens)
texec::Request request{VecTokens(length, length), maxNewTokens};
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(*mContextCommState);
state->setCacheState(*mCacheState);
auto stats = texec::ContextPhaseParams({}, mRequestId, state.release(), std::nullopt);
request.setContextPhaseParams(std::move(stats));
return std::make_unique<LlmRequest>(mRequestId++, std::move(request));
}
void addRequestAndTransportCache(std::shared_ptr<LlmRequest> const& llmRequest)
{
auto constexpr beamIdx{0};
auto constexpr beamWidth{1};
mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest);
if (isSender)
{
auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId);
for (auto& block : blockRange)
{
// fill cache with tokens (= request length), for reuse test
TLLM_CUDA_CHECK(cudaMemset(block.data(), llmRequest->getPromptLen(), block.getSizeInBytes()));
}
mFutures.emplace_back(mResponder->respondAndSendAsync(*llmRequest));
}
else
{
auto future = mRequester->requestAndReceiveAsync(*llmRequest);
future.get();
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId);
for (auto& block : blockRange)
{
std::vector<uint8_t> bytes(block.getSizeInBytes());
TLLM_CUDA_CHECK(cudaMemcpy(bytes.data(), block.data(), block.getSizeInBytes(), cudaMemcpyDeviceToHost));
EXPECT_TRUE(std::all_of(bytes.begin(), bytes.end(),
[&llmRequest](uint8_t i) { return i == (llmRequest->getPromptLen() & 0xff); }));
}
}
}
bool isSender{false};
tensorrt_llm::mpi::MpiComm const* mComm;
SizeType32 mWorldSize{0}, mLocalRank{0};
LlmRequest::RequestIdType mRequestId{0};
SizeType32 mMaxNumSequences{};
std::unique_ptr<KVCacheManager> mManager;
std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
std::unique_ptr<DataResponder> mResponder;
std::unique_ptr<DataRequester> mRequester;
std::unique_ptr<texec::kv_cache::CacheState> mCacheState;
std::unique_ptr<texec::kv_cache::CommState> mContextCommState;
std::vector<std::future<void>> mFutures;
std::unique_ptr<texec::kv_cache::ConnectionManager> mConnectionManager;
};
TEST_F(SymmetricalCacheTest, SimpleTest)
{
auto worldSize = setUpCommunicator();
if (worldSize != 2)
{
GTEST_SKIP() << "mpirun 2 processes is required to run this test.";
}
setUpCacheManager();
setUpCacheTransceiver();
std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>> requests;
for (auto len : {10, 20, 30})
{
requests.emplace_back(makeLlmRequest(len));
addRequestAndTransportCache(requests.back());
}
for (auto& future : mFutures)
{
future.get();
}
mFutures.clear();
for (auto& request : requests)
{
mManager->removeSequence(request->mRequestId, request);
}
requests.clear();
// test reuse
for (auto len : {10, 20, 30})
{
requests.emplace_back(makeLlmRequest(len));
addRequestAndTransportCache(requests.back());
}
for (auto& future : mFutures)
{
future.get();
}
}
#if ENABLE_MULTI_DEVICE
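// Tuple fields: contextTp, contextPp, genTp, genPp, numLayers, numHeads, sizePerHead,
// tokensPerBlock, dataType, kvFactor, isMLA, contextDP, generationDP, isWindow.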
using AsymmetricTestParam
= std::tuple<int, int, int, int, int, int, int, int, nvinfer1::DataType, int, bool, bool, bool, bool>;
class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestParam>
{
protected:
void SetUp() override {}
void TearDown() override {}
void setUpCommunicator(int contextTp, int contextPp, int genTp, int genPp, bool isMLA = false,
bool contextDP = false, bool generationDP = false)
{
#if ENABLE_MULTI_DEVICE
tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
if (tensorrt_llm::mpi::MpiComm::world().getSize() != 8)
{
GTEST_SKIP() << "mpirun with procs=8 is required to run this test.";
}
int worldSize = tensorrt_llm::mpi::MpiComm::world().getSize();
int worldRank = tensorrt_llm::mpi::MpiComm::world().getRank();
tensorrt_llm::mpi::MpiComm::world().barrier();
int contextRanks = contextTp * contextPp;
int genRanks = genTp * genPp;
int nprocs = (contextRanks + genRanks);
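// The first contextTp*contextPp world ranks form the context instance and the next genTp*genPp
// ranks the generation instance; any remaining ranks sit this test case out.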
mIsContext = false;
mIsGeneration = false;
mParticipatingComm = tensorrt_llm::mpi::MpiComm::world().split(static_cast<int>(worldRank < nprocs), worldRank);
tensorrt_llm::mpi::MpiComm::setSession(
tensorrt_llm::mpi::MpiComm::world().split(static_cast<int>(worldRank < nprocs), worldRank));
mIsContext = worldRank < contextRanks;
mIsGeneration = (worldRank >= contextRanks && worldRank < (contextRanks + genRanks));
if (worldRank >= nprocs)
{
return;
}
TLLM_LOG_INFO("Run cacheTransceiverTest for ContextTp: %d, ContextPp: %d, GenTp: %d, GenPp:%d", contextTp,
contextPp, genTp, genPp);
mComm = std::addressof(mParticipatingComm);
mWorldSize = mComm->getSize();
mRank = mComm->getRank();
{
mIsContext = mRank < contextRanks;
mIsGeneration = (mRank >= contextRanks && mRank < (contextRanks + genRanks));
mRankInInstance = mIsContext ? mRank : (mRank - contextRanks);
mSizeInInstance = mIsContext ? (contextTp * contextPp) : (genTp * genPp);
int color = 0;
if (mIsGeneration)
{
color = 1;
}
if (mIsContext)
{
color = 2;
}
auto sessionComm = mComm->split(static_cast<int>(color), mComm->getRank());
if (mIsContext)
{
mTpSize = contextTp;
mPpSize = contextPp;
}
if (mIsGeneration)
{
mTpSize = genTp;
mPpSize = genPp;
}
mTpRank = mRankInInstance % mTpSize;
mPpRank = mRankInInstance / mTpSize;
mContextRankSize = contextRanks;
mGenRankSize = genRanks;
mContextTpSize = contextTp;
mContextPpSize = contextPp;
EXPECT_EQ((sessionComm.getRank()), mRankInInstance);
EXPECT_EQ(sessionComm.getSize(), mSizeInInstance);
mContextDP = contextDP;
mGenerationDP = generationDP;
mIsMLA = isMLA;
tensorrt_llm::mpi::MpiComm::setSession(std::move(sessionComm));
}
#else
GTEST_SKIP() << "ENABLE_MULTI_DEVICE is required to run this test.";
#endif
}
void setUpCacheManager(int numLayers, int numHeads, int sizePerHead, int tokensPerBlock,
nvinfer1::DataType dataType, int kvFactor = 2, bool isMLA = false, bool enableDPAttention = false,
bool isWindow = false)
{
mIsWindowAttention = isWindow;
if (!(mIsContext || mIsGeneration))
{
return;
}
ASSERT_EQ(numLayers % mPpSize, 0);
if (!isMLA)
{
// ASSERT_EQ(numHeads % mTpSize , 0);
ASSERT_TRUE(numHeads % mTpSize == 0 || mTpSize % numHeads == 0);
}
else
{
ASSERT_EQ(numHeads, 1);
}
int numHeadsPerRank = (numHeads + mTpSize - 1) / mTpSize;
mDupHeadFactor = 1;
if (mTpSize > numHeads)
{
mDupHeadFactor = mTpSize / numHeads;
ASSERT_EQ(numHeadsPerRank, 1);
}
if (isMLA || enableDPAttention)
{
numHeadsPerRank = numHeads;
mDupHeadFactor = 1;
}
auto hiddenSize = numHeadsPerRank * sizePerHead;
auto maxBlocksPerSeq = 10;
auto maxBeamWidth = 1;
auto constexpr sinkTokenLength = 0;
mMaxNumSequences = 8;
auto const stream = std::make_shared<tr::CudaStream>();
auto maxNumTokens = tokensPerBlock * maxBlocksPerSeq;
auto windowAttentionToken = 2 * tokensPerBlock;
auto maxAttentionWindow = maxNumTokens;
auto inputLength = maxNumTokens - tokensPerBlock - 1;
auto numSharedBlocks = inputLength / tokensPerBlock;
auto numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth;
auto totalNumBlocks = mMaxNumSequences * numBlocksPerSeq;
auto constexpr blocksInSecondaryPool = 0;
auto constexpr enableBlockReuse = true;
auto constexpr onboardBlocks = true;
CacheType cacheType = CacheType::kSELF;
if (kvFactor == 1)
{
cacheType = CacheType::kSELFKONLY;
}
TLLM_CHECK(kvFactor == 2 || kvFactor == 1);
int DPrank = 0;
int DPsize = 0;
if (mIsContext)
{
enableDPAttention = mContextDP;
DPrank = mTpRank; // overridden per request when constructing the LlmRequest (see makeLlmRequestWithDP)
DPsize = mTpSize;
}
if (mIsGeneration)
{
enableDPAttention = mGenerationDP;
DPrank = mTpRank;
DPsize = mTpSize;
}
int numHeadsPerRankForContext = (numHeads + mContextTpSize - 1) / mContextTpSize;
if (isMLA || mContextDP)
{
numHeadsPerRankForContext = numHeads;
}
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
auto blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}};
std::vector<SizeType32> maxAttentionWindowVec{};
maxAttentionWindowVec.push_back(maxAttentionWindow);
if (mIsWindowAttention)
{
auto attentionNumBlocks = 2 * mMaxNumSequences;
blocksPerWindow[windowAttentionToken] = {attentionNumBlocks, blocksInSecondaryPool};
maxAttentionWindowVec.push_back(windowAttentionToken);
}
TLLM_LOG_DEBUG(" cacheManager isWindowAttention: %d", mIsWindowAttention);
mManager = std::make_unique<KVCacheManager>(numLayers / mPpSize, numHeadsPerRank, sizePerHead, tokensPerBlock,
blocksPerWindow, mMaxNumSequences, maxBeamWidth, maxAttentionWindowVec, std::nullopt, dataType,
sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr,
true);
texec::kv_cache::CacheState::AttentionType attentionType = isMLA
? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
mCacheState = std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRank, sizePerHead,
tokensPerBlock, mTpSize, mPpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize);
mContextCacheState = std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRankForContext,
sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, dataType, attentionType, kvFactor, mContextDP,
DPrank, mContextTpSize);
// UVM appears to be incompatible with MPI; this is still under investigation.
bool constexpr useUvm = false;
mManager->allocatePools(useUvm);
}
void setUpCacheTransceiver()
{
if (!(mIsContext || mIsGeneration))
{
return;
}
else if (tensorrt_llm::common::getEnvUseMPIKvCache() || tensorrt_llm::common::getEnvUseUCXKvCache()
|| tensorrt_llm::common::getEnvUseNixlKvCache())
{
int maxNumTokens = 2048;
mCacheTransBufferManager = std::make_unique<CacheTransBufferManager>(mManager.get(), maxNumTokens);
bool isUcx = tensorrt_llm::common::getEnvUseUCXKvCache();
bool isNixl = tensorrt_llm::common::getEnvUseNixlKvCache();
TLLM_LOG_INFO("Enable %s KV cache transport.", isUcx ? "UCX" : isNixl ? "NIXL" : "MPI");
if (isUcx)
{
std::lock_guard<std::mutex> lock(mDllMutex);
void* WrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
TLLM_CHECK_WITH_INFO(
WrapperLibHandle != nullptr, "Failed to open the UCX wrapper library; dlerror: %s", dlerror());
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
TLLM_CHECK_WITH_INFO(ret != nullptr,
"Unable to load UCX wrapper library symbol; a possible cause is that the TensorRT LLM library was not "
"built with UCX support. Please rebuild in a UCX-enabled environment.");
return ret;
};
std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
*(void**) (&makeUcxConnectionManager) = load_sym(WrapperLibHandle, "makeUcxConnectionManager");
mConnectionManager = makeUcxConnectionManager();
}
else if (isNixl)
{
constexpr auto port = 22345;
setenv("TRTLLM_NIXL_PORT", std::to_string(port).c_str(), 1);
mConnectionManager
= std::make_unique<texec::kv_cache::AgentConnectionManager>(mCacheTransBufferManager.get());
}
else
{
mConnectionManager = std::make_unique<texec::kv_cache::MpiConnectionManager>(mComm);
}
auto makeFormatter
= [this]() { return createCacheFormatter(mManager.get(), mCacheTransBufferManager.get(), mIsMLA); };
if (mIsContext)
{
mResponder = std::make_unique<DataResponder>(std::make_unique<DataSenderImpl>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()));
}
else
{
mRequester = std::make_unique<DataRequester>(std::make_unique<DataReceiverImpl>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()));
}
std::vector<int> contextRankVec(mContextRankSize);
std::iota(contextRankVec.begin(), contextRankVec.end(), 0);
if (isUcx || isNixl)
{
auto commState = mConnectionManager->getCommState();
namespace su = tensorrt_llm::executor::serialize_utils;
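// Same out-of-band handshake as in SymmetricalCacheTest, except rank 0 sends its serialized
// CommState to every generation rank so each one can reach the context instance.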
if (tensorrt_llm::mpi::MpiComm::world().getRank() == 0)
{
std::ostringstream oStream;
su::serialize(commState, oStream);
auto str = oStream.str();
std::vector<char> buffer(str.begin(), str.end());
for (int genRank = mContextRankSize; genRank < mContextRankSize + mGenRankSize; genRank++)
{
int64_t bufferSize = buffer.size();
TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(), "send bufferSize: %ld to %d",
bufferSize, genRank);
tensorrt_llm::mpi::MpiComm::world().sendRawTag(
&bufferSize, 1, tensorrt_llm::mpi::MpiType::kINT64, genRank, 0x1F);
tensorrt_llm::mpi::MpiComm::world().sendRawTag(
buffer.data(), buffer.size(), tensorrt_llm::mpi::MpiType::kCHAR, genRank, 0x2F);
TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(), "send buffer to %d", genRank);
}
}
if (mIsGeneration)
{
int64_t bufferSize;
tensorrt_llm::mpi::MpiComm::world().recvRawTag(
&bufferSize, 1, tensorrt_llm::mpi::MpiType::kINT64, 0, 0x1F);
TLLM_LOG_DEBUG(
tensorrt_llm::mpi::MpiComm::world().getRank(), "recv bufferSize: %ld from 0", bufferSize);
std::vector<char> recvBuffer(bufferSize);
tensorrt_llm::mpi::MpiComm::world().recvRawTag(
recvBuffer.data(), bufferSize, tensorrt_llm::mpi::MpiType::kCHAR, 0, 0x2F);
TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(), "recv buffer (%ld bytes) from 0", bufferSize);
su::VectorWrapBuf<char> strbuf(recvBuffer);
std::istream is(&strbuf);
mContextCommState = std::make_unique<tensorrt_llm::executor::kv_cache::CommState>(
su::deserialize<tensorrt_llm::executor::kv_cache::CommState>(is));
}
if (mIsContext)
{
mContextCommState = std::make_unique<tensorrt_llm::executor::kv_cache::CommState>(commState);
}
TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(), "mContextCommState: %s",
mContextCommState->toString().c_str());
}
else
{
mContextCommState = std::make_unique<tensorrt_llm::executor::kv_cache::CommState>(contextRankVec);
}
}
else
{
TLLM_CHECK_WITH_INFO(false, "No KV cache transfer backend is enabled; enable the MPI, UCX, or NIXL backend via its environment variable.");
}
}
auto makeLlmRequest(SizeType32 length)
{
constexpr SizeType32 maxNewTokens{1};
texec::Request request{VecTokens(length, length), maxNewTokens};
auto state = std::make_unique<texec::DataTransceiverState>();
TLLM_CHECK(mContextCommState);
state->setCommState(texec::kv_cache::CommState{*mContextCommState});
state->setCacheState(*mContextCacheState);
auto stats = texec::ContextPhaseParams({}, mRequestId, state.release(), std::nullopt);
request.setContextPhaseParams(std::move(stats));
return std::make_unique<LlmRequest>(mRequestId++, std::move(request));
}
auto makeLlmRequestWithDP(SizeType32 length, LlmRequest::RequestIdType requestId, int contextDpRank)
{
constexpr SizeType32 maxNewTokens{1};
texec::Request request{VecTokens(length), maxNewTokens};
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{*mContextCommState});
texec::kv_cache::CacheState cacheState{mContextCacheState->getModelConfig().mNbKvHeadsPerLayer,
mContextCacheState->getModelConfig().mSizePerHead, mContextCacheState->getModelConfig().mTokensPerBlock,
mContextCacheState->getParallelConfig().mTensorParallelism,
mContextCacheState->getParallelConfig().mPipelineParallelism, mContextCacheState->getDataType(),
mContextCacheState->getAttentionConfig().mAttentionType, mContextCacheState->getAttentionConfig().mKvFactor,
mContextCacheState->getParallelConfig().mEnableAttentionDP, contextDpRank,
mContextCacheState->getParallelConfig().mTensorParallelism};
state->setCacheState(cacheState);
auto stats = texec::ContextPhaseParams({}, requestId, state.release(), std::nullopt);
request.setContextPhaseParams(std::move(stats));
return std::make_unique<LlmRequest>(requestId, std::move(request));
}
std::future<void> addRequestAndTransportCacheForContext(std::shared_ptr<LlmRequest> const& llmRequest)
{
auto constexpr beamIdx{0};
auto constexpr beamWidth{1};
mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest);
auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId);
int blockIdx = 0;
int const numPools = mManager->getBlockManager().getNumPools();
TLLM_LOG_DEBUG(" addRequestAndTransportCacheForContext mManager numPools: %d", numPools);
for (int poolIdx = 0; poolIdx < numPools; poolIdx++)
{
blockRange.updatePoolIdx(poolIdx);
TLLM_LOG_DEBUG("update poolIdx: %d", poolIdx);
for (auto& block : blockRange)
{
fillBlockData(block, blockIdx, llmRequest->getPromptLen(), poolIdx);
blockIdx++;
}
TLLM_LOG_DEBUG("blockPoolIdx: %d finish fill block data", poolIdx);
}
TLLM_LOG_DEBUG(
"addRequestAndTransportCacheForContext blockManager numPools: %d finish fill block data", numPools);
auto const& blockManager = mManager->getBlockManager();
auto const onlyWindowSize = blockManager.getPoolWindowSize(0);
blockManager.getBufferManager(onlyWindowSize).getStream().synchronize();
auto future = mResponder->respondAndSendAsync(*llmRequest);
return future;
}
std::future<void> addRequestAndTransportCacheForGeneration(std::shared_ptr<LlmRequest> const& llmRequest)
{
auto constexpr beamIdx{0};
auto constexpr beamWidth{1};
mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest);
return mRequester->requestAndReceiveAsync(*llmRequest);
}
void generationVerifyKVCache(std::shared_ptr<LlmRequest> const& llmRequest)
{
auto constexpr beamIdx{0};
auto constexpr beamWidth{1};
int blockIdx = 0;
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId);
auto const numPools = mManager->getBlockManager().getNumPools();
for (int poolIdx = 0; poolIdx < numPools; poolIdx++)
{
blockRange.updatePoolIdx(poolIdx);
for (auto& block : blockRange)
{
verifyBlockData(block, blockIdx, llmRequest->getPromptLen(), poolIdx);
blockIdx++;
}
}
}
void fillBlockData(tensorrt_llm::runtime::ITensor& blockData, int blockId, size_t initial, int blockPoolIdx = 0)
{
auto const& blockManager = mManager->getBlockManager();
auto const onlyWindowSize = blockManager.getPoolWindowSize(blockPoolIdx);
auto const& bufferManager = blockManager.getBufferManager(onlyWindowSize);
auto hostTensor = tensorrt_llm::runtime::BufferManager::cpu(blockData.getShape(), blockData.getDataType());
int layerSizePerRank = blockData.getDimension<1>();
int startLayerId = layerSizePerRank * mPpRank;
int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0);
int startHeadId = headSizePerRank * (mTpRank / mDupHeadFactor);
bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP;
if (mIsMLA || enableDP)
{
startHeadId = 0;
}
int kvFactor = mCacheState->getAttentionConfig().mKvFactor;
int tokensPerBlock = mCacheState->getModelConfig().mTokensPerBlock;
int startTokenId = blockId * tokensPerBlock;
int sizePerHead = mCacheState->getModelConfig().mSizePerHead;
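// Block layout per layer is [K/V][head][token][hidden]: keyIndex addresses a cell in the K half,
// and valueIndex is offset by one head*token*hidden span into the V half of the same layer.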
for (int layerId = 0; layerId < layerSizePerRank; layerId++)
{
for (int headId = 0; headId < headSizePerRank; headId++)
{
for (int tokenId = 0; tokenId < tokensPerBlock; tokenId++)
{
for (int hiddenId = 0; hiddenId < sizePerHead; hiddenId++)
{
size_t keyIndex = layerId * (kvFactor * headSizePerRank * tokensPerBlock * sizePerHead)
+ headId * (tokensPerBlock * sizePerHead) + tokenId * sizePerHead + hiddenId;
size_t valueIndex
= keyIndex + static_cast<size_t>(headSizePerRank * tokensPerBlock * sizePerHead);
std::visit(
[&](auto generateValue)
{
using ValueType = decltype(generateValue);
auto* dataPtr = static_cast<ValueType*>(hostTensor->data(keyIndex));
*dataPtr = generateValue;
},
generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId, layerId + startLayerId,
headId + startHeadId, hiddenId, true, blockData.getDataType()));
if (kvFactor == 2)
{
std::visit(
[&](auto generateValue)
{
using ValueType = decltype(generateValue);
auto* dataPtr = static_cast<ValueType*>(hostTensor->data(valueIndex));
*dataPtr = generateValue;
},
generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId,
layerId + startLayerId, headId + startHeadId, hiddenId, false,
blockData.getDataType()));
}
}
}
}
}
bufferManager.copy(*hostTensor, blockData);
bufferManager.getStream().synchronize();
}
void verifyBlockData(tensorrt_llm::runtime::ITensor& blockData, int blockId, size_t initial, int blockPoolIdx = 0)
{
auto const& blockManager = mManager->getBlockManager();
auto const onlyWindowSize = blockManager.getPoolWindowSize(blockPoolIdx);
auto const& bufferManager = blockManager.getBufferManager(onlyWindowSize);
auto hostTensor = tensorrt_llm::runtime::BufferManager::cpu(blockData.getShape(), blockData.getDataType());
int layerSizePerRank = blockData.getDimension<1>();
int startLayerId = layerSizePerRank * mPpRank;
int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0);
int startHeadId = headSizePerRank * (mTpRank / mDupHeadFactor);
bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP;
if (mIsMLA || enableDP)
{
startHeadId = 0;
}
int kvFactor = mCacheState->getAttentionConfig().mKvFactor;
int tokensPerBlock = mCacheState->getModelConfig().mTokensPerBlock;
int startTokenId = blockId * tokensPerBlock;
int sizePerHead = mCacheState->getModelConfig().mSizePerHead;
bufferManager.copy(blockData, *hostTensor);
bufferManager.getStream().synchronize();
for (int layerId = 0; layerId < layerSizePerRank; layerId++)
{
for (int headId = 0; headId < headSizePerRank; headId++)
{
for (int tokenId = 0; tokenId < tokensPerBlock; tokenId++)
{
for (int hiddenId = 0; hiddenId < sizePerHead; hiddenId++)
{
size_t keyIndex = layerId * (kvFactor * headSizePerRank * tokensPerBlock * sizePerHead)
+ headId * (tokensPerBlock * sizePerHead) + tokenId * sizePerHead + hiddenId;
size_t valueIndex
= keyIndex + static_cast<size_t>(headSizePerRank * tokensPerBlock * sizePerHead);
std::visit(
[&](auto generateValue)
{
using ValueType = decltype(generateValue);
auto* dataPtr = static_cast<ValueType*>(hostTensor->data(keyIndex));
EXPECT_EQ(*dataPtr, generateValue);
},
generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId, layerId + startLayerId,
headId + startHeadId, hiddenId, true, blockData.getDataType()));
if (kvFactor == 2)
{
std::visit(
[&](auto generateValue)
{
using ValueType = decltype(generateValue);
auto* dataPtr = static_cast<ValueType*>(hostTensor->data(valueIndex));
EXPECT_EQ(*dataPtr, generateValue);
},
generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId,
layerId + startLayerId, headId + startHeadId, hiddenId, false,
blockData.getDataType()));
}
}
}
}
}
}
std::variant<double, float, int16_t, int8_t> generateExpectedValue(size_t initial, int blockPoolIdx, int tokenId,
int layerId, int headId, int hiddenId, bool key, nvinfer1::DataType dataType)
{
size_t seed = 0;
std::size_t hashValue = std::hash<size_t>{}(initial);
std::hash<int> hasher{};
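// Mix all cell coordinates into one seed with the boost::hash_combine pattern (golden-ratio
// constant 0x9e3779b9) so every (pool, token, layer, head, hidden, K/V) cell gets a distinct,
// reproducible value.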
seed ^= hashValue + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(blockPoolIdx) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(tokenId) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(layerId) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(headId) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= hasher(hiddenId) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed += key;
generator.seed(seed);
std::uniform_real_distribution<double> dis(-100.0, 100.0);
double value = dis(generator);
auto dataTypeSize = tensorrt_llm::common::getDTypeSize(dataType);
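// Return a value whose type width matches the cache dtype; integer stand-ins are used for the
// 2- and 1-byte types so equality comparison after the transfer round trip is exact.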
switch (dataTypeSize)
{
case 8: return value;
case 4: return static_cast<float>(value);
case 2: return static_cast<int16_t>(value);
case 1: return static_cast<int8_t>(value);
default: TLLM_CHECK_WITH_INFO(false, "generateExpectedValue only supports dataTypeSize in [8,4,2,1]");
}
return 0.F;
}
bool mIsContext{false};
bool mIsGeneration{false};
tensorrt_llm::mpi::MpiComm const* mComm;
tensorrt_llm::mpi::MpiComm mParticipatingComm{nullptr, false};
SizeType32 mWorldSize{0}, mRank{0}, mRankInInstance{0};
SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mContextRankSize{0}, mGenRankSize{0},
mContextTpSize{0}, mContextPpSize{0};
LlmRequest::RequestIdType mRequestId{0};
bool mContextDP{false};
bool mGenerationDP{false};
bool mIsMLA{false};
bool mIsWindowAttention{false};
int mDupHeadFactor{1};
SizeType32 mMaxNumSequences{};
std::unique_ptr<KVCacheManager> mManager;
std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
std::unique_ptr<DataResponder> mResponder;
std::unique_ptr<DataRequester> mRequester;
std::unique_ptr<texec::kv_cache::CacheState> mCacheState;
std::unique_ptr<texec::kv_cache::CacheState> mContextCacheState;
std::unique_ptr<texec::kv_cache::CommState> mContextCommState;
std::unique_ptr<texec::kv_cache::ConnectionManager> mConnectionManager;
std::mt19937 generator;
};
TEST_P(AsymmetricalCacheTest, TestCase)
{
if (!(tensorrt_llm::common::getEnvUseUCXKvCache()))
{
setenv("UCX_TLS", "^cuda_ipc", 1); // disable cuda_ipc for testing for mpi
}
else
{
setenv("UCX_TCP_CM_REUSEADDR", "y",
1); // tests creates and destroies ucxCacheCommunicatoers frequently, so listener ports must be reused
}
AsymmetricTestParam param = GetParam();
int contextTp = std::get<0>(param);
int contextPp = std::get<1>(param);
int genTp = std::get<2>(param);
int genPp = std::get<3>(param);
int numLayers = std::get<4>(param);
int numHeads = std::get<5>(param);
int sizePerHead = std::get<6>(param);
int tokensPerBlock = std::get<7>(param);
nvinfer1::DataType dataType = std::get<8>(param);
int kvFactor = std::get<9>(param);
bool isMLA = std::get<10>(param);
bool contextDP = std::get<11>(param);
bool generationDP = std::get<12>(param);
bool isWindow = std::get<13>(param);
setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP);
if (mIsContext || mIsGeneration)
{
setUpCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, dataType, kvFactor, isMLA, false, isWindow);
setUpCacheTransceiver();
std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>> requests;
// the second iteration exercises KV cache block reuse
for (int i = 0; i < 2; i++)
{
for (auto len : {30, 10, 60, 80})
{
requests.emplace_back(makeLlmRequest(len));
}
if (mIsContext)
{
std::vector<std::future<void>> contextFutures;
for (auto&& request : requests)
{
contextFutures.push_back(addRequestAndTransportCacheForContext(request));
}
mComm->barrier();
for (auto&& cfuture : contextFutures)
{
cfuture.get();
}
}
else
{
std::vector<std::future<void>> generationFutures;
mComm->barrier();
for (auto&& request : requests)
{
generationFutures.push_back(addRequestAndTransportCacheForGeneration(request));
}
for (auto&& gfuture : generationFutures)
{
gfuture.get();
}
for (auto&& request : requests)
{
generationVerifyKVCache(request);
}
}
for (auto&& request : requests)
{
mManager->removeSequence(request->mRequestId, request);
}
requests.clear();
mComm->barrier();
}
}
tensorrt_llm::mpi::MpiComm::world().barrier();
}
class AsymmetricalCacheTestWithDP : public AsymmetricalCacheTest
{
};
TEST_P(AsymmetricalCacheTestWithDP, TestCase)
{
if (!(tensorrt_llm::common::getEnvUseUCXKvCache()))
{
setenv("UCX_TLS", "^cuda_ipc", 1); // disable cuda_ipc for testing for mpi
}
else
{
setenv("UCX_TCP_CM_REUSEADDR", "y",
1); // tests creates and destroies ucxCacheCommunicatoers frequently, so listener ports must be reused
}
AsymmetricTestParam param = GetParam();
int contextTp = std::get<0>(param);
int contextPp = std::get<1>(param);
int genTp = std::get<2>(param);
int genPp = std::get<3>(param);
int numLayers = std::get<4>(param);
int numHeads = std::get<5>(param);
int sizePerHead = std::get<6>(param);
int tokensPerBlock = std::get<7>(param);
nvinfer1::DataType dataType = std::get<8>(param);
int kvFactor = std::get<9>(param);
bool isMLA = std::get<10>(param);
bool contextDP = std::get<11>(param);
bool generationDP = std::get<12>(param);
bool isWindow = std::get<13>(param);
setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP);
if (mIsContext || mIsGeneration)
{
bool enableDP = mIsContext ? contextDP : generationDP;
setUpCacheManager(
numLayers, numHeads, sizePerHead, tokensPerBlock, dataType, kvFactor, isMLA, enableDP, isWindow);
setUpCacheTransceiver();
std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>> requests;
int requestId = 0;
for (auto len : {30, 10, 60, 30, 60, 10})
{
requests.emplace_back(makeLlmRequestWithDP(len, requestId, requestId % contextTp));
requestId++;
}
std::vector<std::future<void>> contextFutures;
std::vector<std::future<void>> generationFutures;
std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>> generationRequests;
if (mIsContext)
{
std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>> contextRequests;
if (contextDP)
{
for (size_t i = 0; i < requests.size(); i++)
{
if (i % static_cast<size_t>(mTpSize) == static_cast<size_t>(mTpRank))
{
// round-robin assignment of requests across DP ranks
contextRequests.push_back(requests[i]);
}
}
}
else
{
contextRequests = requests;
}
for (auto&& request : contextRequests)
{
contextFutures.push_back(addRequestAndTransportCacheForContext(request));
}
mComm->barrier();
}
else
{
if (generationDP)
{
for (size_t i = 0; i < requests.size(); i++)
{
if (i % static_cast<size_t>(mTpSize) == static_cast<size_t>(mTpRank))
{
generationRequests.push_back(requests[i]);
}
}
}
else
{
generationRequests = requests;
}
mComm->barrier();
for (auto&& request : generationRequests)
{
generationFutures.push_back(addRequestAndTransportCacheForGeneration(request));
}
}
if (mIsContext)
{
for (auto&& cfuture : contextFutures)
{
cfuture.get();
}
}
else
{
for (auto&& gfuture : generationFutures)
{
gfuture.get();
}
for (auto&& request : generationRequests)
{
generationVerifyKVCache(request);
}
}
mComm->barrier();
}
tensorrt_llm::mpi::MpiComm::world().barrier();
}
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0, AsymmetricalCacheTest,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true, false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithWindow, AsymmetricalCacheTest,
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(5),
testing::Values(4), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1, AsymmetricalCacheTest,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(8),
testing::Values(4), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2, AsymmetricalCacheTest,
testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1, 4),
testing::Values(16), testing::Values(16), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), testing::Values(false),
testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0ForMLA, AsymmetricalCacheTest,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1ForMLA, AsymmetricalCacheTest,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(4),
testing::Values(1), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(true), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA2, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(true), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA2, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(4), testing::Values(1),
testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(2), testing::Values(2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(4, 2), testing::Values(1),
testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1, 2), testing::Values(2),
testing::Values(4), testing::Values(1, 2), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
#endif
TEST(targetTest, CacheStateNODP)
{
int const numLayers = 16;
int const numHeads = 2;
int const sizePerHead = 64;
int const tokensPerBlock = 64;
auto const dataType = nvinfer1::DataType::kFLOAT;
bool const isMLA = true;
int const kvFactor = 2;
int contextPP = 2;
int contextTP = 4;
int genPP = 2;
int genTP = 2;
bool const contextEnableDP = false;
bool const genEnableDP = false;
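// TargetRanksInfoForDP reports, for one context rank, which generation ranks it must send cache
// to plus the pipeline/tensor parallel domain sizes of that mapping; needSendCache reports whether
// this rank sends at all (with MLA, TP ranks holding duplicated KV data can skip sending).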
auto const verifyContext = [&](int contextRank, std::vector<int> const& expectRanks, int expectPPDomain,
int expectTPDomain, bool expectNeedSend)
{
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, contextTP, contextPP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0};
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, 0, 0};
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
EXPECT_EQ(expectNeedSend, MLACacheFormatter::needSendCache(contextCache, genCache, contextRank));
};
verifyContext(
/*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 5, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 6, /*expectRanks*/ {3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 7, /*expectRanks*/ {3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
contextTP = 2;
genTP = 4;
verifyContext(
/*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 2, /*expectRanks*/ {4, 5}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
verifyContext(/*contextRank*/ 3, /*expectRanks*/ {6, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
contextPP = 1;
verifyContext(
/*contextRank*/ 0, /*expectRanks*/ {0, 4, 1, 5}, /*expectPPDomain*/ 2, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 6, 3, 7}, /*expectPPDomain*/ 2, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
}
TEST(targetTest, CacheStateContextDP)
{
int const numLayers = 16;
int const numHeads = 2;
int const sizePerHead = 64;
int const tokensPerBlock = 64;
auto const dataType = nvinfer1::DataType::kFLOAT;
bool const isMLA = true;
int const kvFactor = 2;
int contextPP = 1;
int contextTP = 4;
int genPP = 1;
int genTP = 2;
bool contextEnableDP = true;
bool genEnableDP = true;
auto const verifyContext = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
int expectPPDomain, int expectTPDomain, bool expectNeedSend)
{
int contextDPRank = contextRank % contextTP;
int generationDPRank = generationRank % genTP;
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
auto const contextCache
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP,
contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
EXPECT_EQ(expectNeedSend, MLACacheFormatter::needSendCache(contextCache, genCache, contextRank));
};
verifyContext(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 2, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 3, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 3, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
contextEnableDP = false;
verifyContext(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 2, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 3, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ false);
verifyContext(
/*contextRank*/ 3, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
/*expectNeedSend*/ false);
contextEnableDP = true;
genEnableDP = false;
verifyContext(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 2, /*generationRank*/ 1, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 3, /*generationRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
verifyContext(
/*contextRank*/ 3, /*generationRank*/ 1, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
/*expectNeedSend*/ true);
contextTP = 1;
genTP = 2;
auto const verifyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
int expectPPDomain, int expectTPDomain)
{
int contextDPRank = contextRank % contextTP;
int generationDPRank = generationRank % genTP;
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
auto const contextCache
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP,
contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
};
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
contextTP = 1;
contextPP = 1;
genTP = 1;
genPP = 2;
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
genEnableDP = false;
contextEnableDP = true;
contextTP = 2;
contextPP = 1;
genTP = 1;
genPP = 1;
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verifyGeneration(
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
}