TensorRT-LLMs/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/evictionPolicy.h"
#include "tensorrt_llm/batch_manager/kvCacheTransferManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <algorithm>
#include <map>
#include <optional>
#include <utility>
namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
namespace tle = tensorrt_llm::executor;
using namespace tle::kv_cache;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
using namespace tensorrt_llm::batch_manager::eviction_policy;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
namespace
{
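//! \brief Extract the byte at position byteIdx (0 = most significant) from a big-endian 32-bit hash part.
//! For example, getNthByte(0x00010203, 0) == 0x00 and getNthByte(0x00010203, 3) == 0x03.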
inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept
{
return static_cast<uint8_t>((hashPart >> (24 - byteIdx * 8)) & 0xFF);
}
//! \brief Get all blocks in a sequence by traversing backwards from the last block.
//! \param lastBlock is a BlockPtr to the last block in the sequence to start traversal from
//! \return Vector of BlockPtr-s in sequence order
std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
{
// First count the number of blocks to pre-allocate the vector
auto currentBlock = lastBlock;
size_t blockCount = 0;
while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
blockCount++;
currentBlock = currentBlock->getPrevBlockInSeq();
}
if (blockCount == 0)
{
return {};
}
// Create and pre-allocate the vector with the correct size
std::vector<BlockPtr> sequenceBlocks(blockCount);
// Now traverse backwards and fill from the end
currentBlock = lastBlock;
size_t currentIndex = blockCount - 1;
while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
sequenceBlocks[currentIndex--] = currentBlock;
currentBlock = currentBlock->getPrevBlockInSeq();
}
return sequenceBlocks;
}
} // namespace
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
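//! \brief Collect extra keys for multimodal content that overlaps the token range [startTokenIdx, endTokenIdx).
//! Each extra key pairs the 32-byte multimodal item hash with the offset of this block's tokens within that item
//! (0 when the item starts inside the block). Returns an empty vector if the request carries no (or inconsistent)
//! multimodal metadata.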
std::vector<MmKey> generateBlockHashExtraKeys(
tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx)
{
auto const multimodalHashes = llmRequest.getMultimodalHashes();
auto const multimodalPositions = llmRequest.getMultimodalPositions();
auto const multimodalLengths = llmRequest.getMultimodalLengths();
if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
|| (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
|| !(*multimodalLengths) || (*multimodalLengths)->empty())
{
return {};
}
if ((*multimodalHashes)->size() != (*multimodalPositions)->size()
|| (*multimodalPositions)->size() != (*multimodalLengths)->size())
{
TLLM_LOG_WARNING("Multimodal data arrays have mismatched sizes");
return {};
}
std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
extraKeys.reserve((*multimodalPositions)->size());
std::array<uint8_t, 32> mmHashArray;
for (size_t i = 0; i < (*multimodalPositions)->size(); ++i)
{
auto const& startPos = (*(*multimodalPositions))[i];
auto const& length = (*(*multimodalLengths))[i];
auto const& mmHashVector = (*(*multimodalHashes))[i];
TLLM_CHECK_WITH_INFO(mmHashVector.size() == 8, "Multimodal hash vector has unexpected size: %zu (expected 8)",
mmHashVector.size());
// mmHashVector[j] comes from Python's int(hex_chunk, 16)
// where hex_chunk like "00010203" means 0x00 is MSB and 0x03 is LSB (big endian)
// Convert 8x 32-bit integers into a 32-byte array preserving Blake3 hash byte order
// Example: hashPart = 0x00010203 → mmHashArray[0:3] = [0x00, 0x01, 0x02, 0x03]
for (size_t j = 0; j < 8; ++j)
{
auto const& hashPart = mmHashVector[j];
for (uint8_t byteIdx = 0; byteIdx < 4; ++byteIdx)
{
mmHashArray[j * 4 + byteIdx] = getNthByte(hashPart, byteIdx);
}
}
// Check if this multimodal content overlaps with the current block
if (endTokenIdx > startPos && startTokenIdx < startPos + length)
{
uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
extraKeys.emplace_back(mmHashArray, mmStartInBlock);
}
}
return extraKeys;
}
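//! \brief Build one BlockKey per block of unique tokens, attaching the extra multimodal keys that overlap each
//! block's token range as well as the request's LoRA task ID and cache salt.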
std::vector<BlockKey> buildBlockKeys(
std::list<VecUniqueTokens>& blockedUniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
{
std::vector<BlockKey> blockKeys;
SizeType32 currentTokenIdx = 0;
for (auto& uniqueTokens : blockedUniqueTokens)
{
auto extraKeys = generateBlockHashExtraKeys(llmRequest, currentTokenIdx, currentTokenIdx + uniqueTokens.size());
currentTokenIdx += uniqueTokens.size();
blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID());
}
return blockKeys;
}
bool BlockKey::operator==(BlockKey const& other) const noexcept
{
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens
&& extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
}
size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept
{
// Hashing algorithm adapted from StackOverflow:
// https://stackoverflow.com/questions/664014/what-integer-hash-function-are-good-that-accepts-an-integer-hash-key
// Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);
if (parentHash == 0 && blockKey.cacheSaltID)
{
// Only hashing the cache salt ID for the first block in the sequence
uint64_t c = blockKey.cacheSaltID.value();
c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
c = c ^ (c >> 31);
seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
for (auto const& uniqueToken : blockKey.uniqueTokens)
{
uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);
a = ((a >> 16) ^ a) * 0x45d9f3b;
a = ((a >> 16) ^ a) * 0x45d9f3b;
a = (a >> 16) ^ a;
seed ^= a + 0x9e3779b9 + (seed << 6) + (seed >> 2);
if (blockKey.usesExtraIds)
{
uint64_t b = uniqueToken.tokenExtraId;
b = (b ^ (b >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
b = (b ^ (b >> 27)) * UINT64_C(0x94d049bb133111eb);
b = b ^ (b >> 31);
seed ^= b + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
}
if (blockKey.loraTaskId)
{
uint64_t c = blockKey.loraTaskId.value();
c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
c = c ^ (c >> 31);
seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// Add extra keys for multimodal data, mixing in the external multimodal item hash and the token offset within
// this sequence block
if (!blockKey.extraKeys.empty())
{
for (auto const& [mmHash, startOffset] : blockKey.extraKeys)
{
// Hash the multimodal hash array in 32-bit chunks (more efficient)
for (size_t i = 0; i < 32; i += 4)
{
// Combine 4 bytes into a 32-bit word (construct as little endian order)
uint32_t word = static_cast<uint32_t>(mmHash[i]) | (static_cast<uint32_t>(mmHash[i + 1]) << 8)
| (static_cast<uint32_t>(mmHash[i + 2]) << 16) | (static_cast<uint32_t>(mmHash[i + 3]) << 24);
// Mix the word into the seed
word = ((word >> 16) ^ word) * 0x45d9f3b;
word = ((word >> 16) ^ word) * 0x45d9f3b;
word = (word >> 16) ^ word;
seed ^= word + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// Hash the start offset
uint64_t e = static_cast<uint64_t>(startOffset);
e = (e ^ (e >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
e = (e ^ (e >> 27)) * UINT64_C(0x94d049bb133111eb);
e = e ^ (e >> 31);
seed ^= e + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
}
return seed;
}
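// Illustrative sketch (not part of the production path): block hashes chain through parentHash, so two requests
// sharing a token prefix produce identical hashes for their shared blocks. Assuming `blockKeys` came from
// buildBlockKeys() for a request:
//
//   size_t parentHash = 0;
//   for (auto const& key : blockKeys)
//   {
//       parentHash = BlockKeyHasher::hash(key, parentHash); // same chaining as KVCacheBlock::setHash()
//   }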
KVCacheBlock::KVCacheBlock(IdType blockId, tk::KVCacheIndex blockIdx)
: mBlockId(blockId)
, mMemoryPoolBlockIndex{blockIdx}
, mRefCount(0)
, mSchedulingRefCount(0)
, mPrevBlock(nullptr)
, mFreeBlockIterator(std::nullopt)
, mIsFull{false}
, mPriority{executor::KvCacheRetentionConfig::kDefaultRetentionPriority}
, mDurationMs{std::nullopt}
, mExpirationTime{std::nullopt}
, mHash{0}
{
}
void KVCacheBlock::startScheduling()
{
mSchedulingRefCount = mRefCount;
}
KVCacheBlock::IdType KVCacheBlock::getBlockId() const
{
return mBlockId;
}
NextBlockMap KVCacheBlock::getNextBlocks() const
{
return mNextBlocks;
}
tk::KVCacheIndex::UnderlyingType KVCacheBlock::getMemoryPoolBlockIndex() const
{
return mMemoryPoolBlockIndex.get();
}
bool KVCacheBlock::isPrimary() const
{
return mMemoryPoolBlockIndex.isPrimary();
}
void KVCacheBlock::swapMemoryPoolBlockOffset(std::shared_ptr<KVCacheBlock> otherBlock)
{
std::swap(mMemoryPoolBlockIndex, otherBlock->mMemoryPoolBlockIndex);
}
void KVCacheBlock::incRefCount()
{
mRefCount++;
}
void KVCacheBlock::decRefCount()
{
TLLM_CHECK_WITH_INFO(
hasRefs(), "Can't remove link from block (id=%d) that is not allocated", static_cast<int>(mBlockId));
mRefCount--;
}
void KVCacheBlock::decSchedulingRefCount()
{
TLLM_CHECK_WITH_INFO(hasSchedulingRefs(), "Can't remove link from block that is not allocated");
mSchedulingRefCount--;
}
bool KVCacheBlock::hasRefs() const
{
return mRefCount > 0;
}
bool KVCacheBlock::isShared() const
{
// block is considered shared if ready for reuse
return mRefCount > 1 || mPrevBlock != nullptr;
}
bool KVCacheBlock::hasSchedulingRefs() const
{
return mSchedulingRefCount > 0;
}
void KVCacheBlock::setBlockKey(BlockKey const& blockKey, bool isFull)
{
mBlockKey = blockKey;
mIsFull = isFull;
}
BlockKey KVCacheBlock::getBlockKey()
{
return mBlockKey;
}
void KVCacheBlock::setPriority(executor::RetentionPriority priority)
{
mPriority = priority;
}
executor::RetentionPriority KVCacheBlock::getPriority() const
{
return mPriority;
}
std::optional<std::chrono::milliseconds> KVCacheBlock::getDurationMs() const
{
return mDurationMs;
}
void KVCacheBlock::setDurationMs(std::optional<std::chrono::milliseconds> durationMs)
{
mDurationMs = durationMs;
}
void KVCacheBlock::setExpirationTime(std::optional<std::chrono::steady_clock::time_point::duration> expirationTime)
{
mExpirationTime = expirationTime;
}
std::optional<std::chrono::steady_clock::time_point::duration> KVCacheBlock::getExpirationTime() const
{
return mExpirationTime;
}
void KVCacheBlock::setHash(size_t hash)
{
mHash = hash;
}
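// Recompute this block's hash from its key, chaining the hash of the previous block in the sequence as the
// parent hash (0 for the first block in a sequence).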
void KVCacheBlock::setHash()
{
mHash = BlockKeyHasher()(mBlockKey, mPrevBlockInSeq ? mPrevBlockInSeq->getHash() : 0);
}
size_t KVCacheBlock::getHash() const
{
return mHash;
}
VecUniqueTokens const& KVCacheBlock::getUniqueTokens() const
{
return mBlockKey.uniqueTokens;
}
BlockPtr const& KVCacheBlock::getPrevBlock() const
{
return mPrevBlock;
}
void KVCacheBlock::setPrevBlock(BlockPtr prevBlock)
{
mPrevBlock = std::move(prevBlock);
}
BlockPtr const& KVCacheBlock::getPrevBlockInSeq() const
{
return mPrevBlockInSeq;
}
void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock)
{
mPrevBlockInSeq = std::move(prevBlock);
}
void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block)
{
if (mNextBlocks.find(blockKey) == mNextBlocks.end())
{
mNextBlocks[blockKey] = std::move(block);
}
}
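//! \brief Look up a child block whose key matches blockKey.
//! \return {partialMatch, numMatchedTokens, block}; partialMatch is true when only a prefix of the key's tokens
//! matched (partial reuse) or when the exactly-matched block is not full; {false, 0, nullptr} when no match exists.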
std::tuple<bool, SizeType32, BlockPtr> KVCacheBlock::findMatchingBlock(
BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const
{
if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0)
{
return {false, 0, nullptr};
}
auto itr = mNextBlocks.find(blockKey);
if (itr == mNextBlocks.end())
{
if (enablePartialReuse)
{
SizeType32 bestNumMatched{0};
BlockPtr bestBlock{nullptr};
for (auto const& [key, block] : mNextBlocks)
{
if (copyOnPartialReuse || (!block->hasRefs() && block->isLeaf()))
{
SizeType32 numMatched = key.partialMatch(blockKey);
if (numMatched > bestNumMatched)
{
bestNumMatched = numMatched;
bestBlock = block;
}
}
}
if (bestNumMatched > 0)
{
return {true, bestNumMatched, bestBlock};
}
}
return {false, 0, nullptr};
}
auto block = itr->second;
return {!block->isFull(), static_cast<SizeType32>(blockKey.uniqueTokens.size()), block};
}
void KVCacheBlock::freeLeafBlock()
{
// assure that this is a leaf block
TLLM_CHECK(isLeaf());
// free from previous block
if (mPrevBlock != nullptr)
{
mPrevBlock->removeNextBlock(mBlockKey);
mPrevBlock = nullptr;
}
}
void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
{
mNextBlocks.erase(blockKey);
}
bool KVCacheBlock::isFull() const
{
return mIsFull;
}
bool KVCacheBlock::isLeaf() const
{
return mNextBlocks.empty();
}
// This function calculates the number of blocks each layer should get, given
// the total free memory and the window size of each layer.
// For example, if we have 1 layer of window size 1024, 2 layers of window
// size 2048, and 3 layers of window size 4096:
// Each layer of window size 1024 should get
// 1024 / (1024 + 2048 * 2 + 4096 * 3) of the total blocks.
// Each layer of window size 2048 should get
// 2048 / (1024 + 2048 * 2 + 4096 * 3) of the total blocks.
// Each layer of window size 4096 should get
// 4096 / (1024 + 2048 * 2 + 4096 * 3) of the total blocks.
// NOTE: This function is currently not used by
// BaseKVCacheManager::calculateMaxNumBlocks because we first want to
// achieve performance identical to treating all layers as full attention.
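// Worked example (illustrative): for windowSizeToLayers = {1024: [0], 2048: [1, 2], 4096: [3, 4, 5]} with
// windowSizeToCacheSizePerToken proportional to the layer counts (c, 2c, 3c), the contributions are
// 1024*c, 2048*2c and 4096*3c, giving shares of 1024/17408, 4096/17408 and 12288/17408, i.e. the per-layer
// proportions described above.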
std::map<SizeType32, float> BlockManager::calculateWindowSizeToShare(
std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers,
std::map<SizeType32, SizeType32> const& windowSizeToCacheSizePerToken)
{
if (windowSizeToLayers.size() == 1)
{
return {{windowSizeToLayers.begin()->first, 1.0f}};
}
std::map<SizeType32, float> windowSizeToContribution;
SizeType32 cacheSizePerTokenTotal
= std::accumulate(windowSizeToCacheSizePerToken.begin(), windowSizeToCacheSizePerToken.end(), SizeType32{0},
[](auto sum, auto const& windowSize) { return sum + windowSize.second; });
for (auto const& [windowSize, cacheSizePerToken] : windowSizeToCacheSizePerToken)
{
auto const cacheSizeWeight = static_cast<float>(cacheSizePerToken) / cacheSizePerTokenTotal;
windowSizeToContribution[windowSize] = cacheSizeWeight;
}
for (auto const& [windowSize, _] : windowSizeToLayers)
{
windowSizeToContribution.at(windowSize) *= windowSize;
}
auto const windowSizesTotalSum = std::accumulate(windowSizeToContribution.begin(), windowSizeToContribution.end(),
0.0, [](auto sum, auto const& windowSize) { return sum + windowSize.second; });
std::map<SizeType32, float> windowSizeToShare;
for (auto const& [windowSize, windowSizeSum] : windowSizeToContribution)
{
float const fraction = windowSizeSum / windowSizesTotalSum;
TLLM_CHECK(0.0f < fraction && fraction <= 1.0f);
windowSizeToShare[windowSize] = fraction;
}
auto total = std::accumulate(windowSizeToShare.begin(), windowSizeToShare.end(), 0.0f,
[](auto sum, auto const& windowSize) { return sum + windowSize.second; });
TLLM_CHECK(total == 1.0f);
return windowSizeToShare;
}
BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
std::shared_ptr<runtime::CudaStream> stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
std::optional<BaseAgentConfig> agentConfig, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize,
SizeType32 indexerKCacheIndexHeadDim)
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
, mTokensPerBlock{tokensPerBlock}
, mEventManager{std::move(eventManager)}
, mStream{stream}
, mCacheType{cacheType}
, mIsEnableIndexerKCache{enableIndexerKCache}
, mIndexerKCacheQuantBlockSize{indexerKCacheQuantBlockSize}
, mIndexerKCacheIndexHeadDim{indexerKCacheIndexHeadDim}
{
if (agentConfig.has_value())
mLoopbackAgent = makeLoopbackAgent("nixl", &agentConfig.value());
else
mLoopbackAgent = nullptr;
auto const uniqueWindowSizeToLayers
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
"KV Cache Connector is not supported with multiple window sizes");
auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());
mIsVariableWindow = numUniqueWindowSizes > 1;
mIsVariableGQA = std::unordered_set(numKvHeadsPerLayer.begin(), numKvHeadsPerLayer.end()).size() > 1;
mLayerToWindowSize.resize(mNumLayers);
for (auto const& [windowSize, layersWithWindowSize] : uniqueWindowSizeToLayers)
{
if (windowSize > maxSequenceLength)
{
TLLM_LOG_WARNING("[kv cache manager] window size %d is greater than max sequence length %d", windowSize,
maxSequenceLength);
}
for (auto& layerIdx : layersWithWindowSize)
{
mLayerToWindowSize.at(layerIdx) = windowSize;
}
auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize);
TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks...
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
sizePerHead, tokensPerBlock, /*isSWA=*/windowSize < maxSequenceLength, allottedPrimaryBlocks,
allottedSecondaryBlocks, maxNumSequences, stream, onboardBlocks, cacheType, secondaryOffloadMinPriority,
mEventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, mLoopbackAgent,
enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim);
}
auto const numAllPools = getNumPools();
mAbsolutePoolToWindowSize.reserve(numAllPools);
mAbsolutePoolToRelativePoolIndex.reserve(numAllPools);
auto absolutePoolsOffset = SizeType32{0};
for (auto const& [windowSize, manager] : mWindowBlockManagers)
{
auto const numPools = manager.getNumPools();
for (auto i = 0; i < numPools; ++i)
{
mAbsolutePoolToWindowSize.push_back(windowSize);
mAbsolutePoolToRelativePoolIndex.push_back(i);
}
// SWA allocates blocks linearly, and we need as many blocks as full attention,
// where full attention has windowSize = maxSequenceLength.
auto const maxTokenNum = std::max(windowSize, maxSequenceLength) + sinkBubbleLength;
auto const temporaryAttentionWindow = manager.calculateTemporaryAttentionWindow(tempAttentionWindowInputs);
// Consider the temporaryAttentionWindow when allocating blocks.
// The current tempAttentionWindow calculation does not yet account for
// SWA, which occupies at most maxSequenceLength worth of blocks, so
// maxTokenNum + temporaryAttentionWindow may exceed maxSequenceLength.
// As a temporary resolution, cap the calculation at maxSequenceLength.
// A follow-up MR will remove the tempAttentionWindow concept.
auto const maxBlocksPerSeq
= tc::ceilDiv(std::min(maxSequenceLength, maxTokenNum + temporaryAttentionWindow), tokensPerBlock);
auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize);
mWindowSizeToMetadata[windowSize] = WindowSizeMetadata{allottedPrimaryBlocks, allottedSecondaryBlocks,
absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, manager.getMaxNumBlocks(),
temporaryAttentionWindow, windowSize, manager.isSWA()};
TLLM_LOG_INFO(
"Max KV cache blocks per sequence: %d [window size=%d], tokens per block=%d, primary blocks=%d, secondary "
"blocks=%d, max sequence length=%d",
maxBlocksPerSeq, windowSize, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks,
maxSequenceLength);
TLLM_LOG_DEBUG(
"%s Metadata: %s", manager.getLogPrefix().c_str(), mWindowSizeToMetadata[windowSize].toString().c_str());
absolutePoolsOffset += numPools;
}
TLLM_CHECK_WITH_INFO(mWindowBlockManagers.size() == mWindowSizeToMetadata.size()
&& std::equal(mWindowBlockManagers.cbegin(), mWindowBlockManagers.cend(), mWindowSizeToMetadata.cbegin(),
mWindowSizeToMetadata.cend(),
[](auto const& window1, auto const& window2) { return window1.first == window2.first; }),
"Iteration order of window sizes between mWindowBlockManagers and mWindowSizeToMetadata *must* be ensured. "
"Maybe you tried changing either of them to an std::unordered_map?");
}
WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent, bool enableIndexerKCache,
SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim)
: mDataType{dtype}
, mWindowSize{windowSize}
, mNumPrimaryBlocks{blocksInPrimaryPool}
, mNumSecondaryBlocks{blocksInSecondaryPool}
, mOnboardBlocks(onboardBlocks)
, mBufferManager{std::move(stream)}
, mSchedulingNumFreeBlocks{0}
, mTokensPerBlock{tokensPerBlock}
, mIsSWA{isSWA}
, mCachedBlocksRoot{std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId, tk::KVCacheIndex{0})}
, mCacheType{cacheType}
, mEventManager(std::move(eventManager))
, mLoopbackAgent{loopbackAgent}
, mTransferManager{std::make_shared<KVCacheTransferManager>(mBufferManager, mLoopbackAgent)}
, mAllocTotalBlocks{0}
, mAllocNewBlocks{0}
, mReusedBlocks{0}
, mReusedUniqueBlocks{0}
, mMissedBlocks{0}
, mKVFactor{mCacheType == CacheType::kSELFKONLY ? 1 : 2}
, mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)}
, mReusedTokens{0.0}
, mTotalInputTokens{0.0}
, mEnablePartialReuse{enablePartialReuse}
, mCopyOnPartialReuse{copyOnPartialReuse}
, mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
, mEnableIndexerKCache{enableIndexerKCache}
, mIndexerKCacheQuantBlockSize{indexerKCacheQuantBlockSize}
, mIndexerKCacheIndexHeadDim{indexerKCacheIndexHeadDim}
{
std::map<SizeType32, SizeType32> numLayersPerPool;
for (auto const layerIdx : managedLayers)
{
auto const& layerIndexWithinPool = numLayersPerPool[numKvHeadsPerLayer.at(layerIdx)]++;
mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool;
}
auto numEltsPerContainer = getNumEltsPerContainer();
#ifdef ENABLE_FP4
if (numEltsPerContainer == 2)
{
TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache.");
}
#endif
size_t poolIndex = 0;
for (auto const [numKvHeads, numLayers] : numLayersPerPool)
{
for (auto const layerIdx : managedLayers)
{
if (numKvHeadsPerLayer.at(layerIdx) == numKvHeads)
{
mLayerToPoolIndex[layerIdx] = poolIndex;
}
}
mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock);
++poolIndex;
}
#ifdef ENABLE_FP4
// TODO(miovine): make the block size configurable. Should we have an additional argument
// to specify FP4 related parameters (scale dtypes, etc)? This can also be passed
// in the constructor.
constexpr SizeType32 kQuantBlockSizeNVFP4 = 16;
if (dtype == nvinfer1::DataType::kFP4)
{
createBlockScalePools(kQuantBlockSizeNVFP4);
}
#endif
if (mEnableIndexerKCache)
{
createIndexerKCachePools();
}
// Create free blocks
mAllBlocksById.reserve(blocksInPrimaryPool + blocksInSecondaryPool);
for (KVCacheBlock::IdType blockId = 0; blockId < blocksInPrimaryPool; ++blockId)
{
mAllBlocksById.emplace_back(std::make_shared<KVCacheBlock>(blockId, tk::KVCacheIndex{blockId, false}));
}
for (KVCacheBlock::IdType blockId = 0; blockId < blocksInSecondaryPool; ++blockId)
{
mAllBlocksById.emplace_back(
std::make_shared<KVCacheBlock>(blocksInPrimaryPool + blockId, tk::KVCacheIndex{blockId, true}));
}
mAllocatedBlocksPerSeq.reserve(maxNumSequences);
mEvictionPolicy = std::make_shared<LRUEvictionPolicy>();
mEvictionPolicy->initialize(
mAllBlocksById, {blocksInPrimaryPool, blocksInSecondaryPool}, secondaryOffloadMinPriority);
if (mEventManager)
{
mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool}, mWindowSize);
}
}
WindowBlockManager::~WindowBlockManager()
{
float reusedUniqueBlocksPercentage = mReusedUniqueBlocks == 0 || mAllocTotalBlocks == 0
? 0
: static_cast<float>(mReusedUniqueBlocks) / static_cast<float>(mAllocNewBlocks) * 100;
float cacheHitRate = mReusedBlocks == 0
? 0
: static_cast<float>(mReusedBlocks) / (static_cast<float>(mReusedBlocks + mMissedBlocks));
TLLM_LOG_DEBUG("%s - total allocated blocks: %lu ", mLogPrefix.c_str(), mAllocTotalBlocks);
TLLM_LOG_DEBUG("%s - allocated new blocks: %lu ", mLogPrefix.c_str(), mAllocNewBlocks);
TLLM_LOG_DEBUG("%s - missed blocks: %lu ", mLogPrefix.c_str(), mMissedBlocks);
TLLM_LOG_DEBUG("%s - reused blocks: %lu ", mLogPrefix.c_str(), mReusedBlocks);
TLLM_LOG_DEBUG("%s - reused unique blocks: %lu ", mLogPrefix.c_str(), mReusedUniqueBlocks);
TLLM_LOG_DEBUG(
"%s - reused unique blocks percentage (%%): %.2f ", mLogPrefix.c_str(), reusedUniqueBlocksPercentage);
TLLM_LOG_DEBUG("%s - cache hit rate: %.2f ", mLogPrefix.c_str(), cacheHitRate);
TLLM_LOG_DEBUG("%s - reused tokens: %.0f ", mLogPrefix.c_str(), mReusedTokens);
TLLM_LOG_DEBUG("%s - reused tokens percentage (%%): %.2f ", mLogPrefix.c_str(),
100.0 * mReusedTokens / mTotalInputTokens);
}
bool BlockManager::verifyQueueIntegrity(SizeType32 windowSize)
{
return mWindowBlockManagers.at(windowSize).verifyQueueIntegrity();
}
bool WindowBlockManager::verifyQueueIntegrity()
{
return mEvictionPolicy->verifyQueueIntegrity();
}
void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest)
{
constexpr int beamIdx = 0; // no need to consider more than one beam for input tokens
for (auto const& [windowSize, _] : mWindowBlockManagers)
{
if (mWindowBlockManagers.at(windowSize).isSWA())
{
// SWA cannot store new blocks on the fly because a stored block may go
// out-of-window (OOW) and be reused by another sequence.
continue;
}
auto cacheBlockIds = sequence.getCacheBlockIds(windowSize);
auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx);
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
(void) mWindowBlockManagers.at(windowSize).storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}
}
void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
{
SizeType32 const numEltsPerContainer = getNumEltsPerContainer();
SizeType32 numPools = mPools.size();
for (SizeType32 i = 0; i < numPools; ++i)
{
auto& kvPool = mPools[i];
if (kvPool.containsIndexerKCache || kvPool.containsBlockScales)
{
continue;
}
TLLM_CHECK_WITH_INFO((kvPool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0,
"Cannot use FP4 quantization since kvPool.sizePerHead is not divisible by FP4 quantBlockSize.");
auto blockScaleSizePerHead = kvPool.sizePerHead * numEltsPerContainer / quantBlockSize;
mPools.emplace_back(kvPool.numLayers, kvPool.kvFactor, kvPool.numKvHeads, blockScaleSizePerHead,
kvPool.tokensPerBlock,
/*primaryPool=*/nullptr,
/*secondaryPool=*/nullptr,
/*containsBlockScales=*/true,
/*containsIndexerKCache=*/false);
}
}
void WindowBlockManager::createIndexerKCachePools()
{
SizeType32 numPools = mPools.size();
for (SizeType32 i = 0; i < numPools; ++i)
{
auto& kvPool = mPools[i];
if (kvPool.containsIndexerKCache || kvPool.containsBlockScales)
{
continue;
}
SizeType32 scaleSize = mIndexerKCacheIndexHeadDim / mIndexerKCacheQuantBlockSize * 4;
mPools.emplace_back(kvPool.numLayers, kvPool.kvFactor, 1, scaleSize + mIndexerKCacheIndexHeadDim,
kvPool.tokensPerBlock,
/*primaryPool=*/nullptr,
/*secondaryPool=*/nullptr,
/*containsBlockScales=*/false,
/*containsIndexerKCache=*/true);
}
}
void BlockManager::allocatePools(bool useUvm)
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.allocatePools(useUvm);
}
}
void WindowBlockManager::allocatePools(bool useUvm)
{
constexpr nvinfer1::DataType kScaleDtypeNVFP4 = nvinfer1::DataType::kFP8;
// Allocate a memory pool backing the blocks for each numKvHeads
// TODO(oargov): allocate pools in a single buffer and split it, to avoid fragmentation
for (auto& pool : mPools)
{
auto blockSize = pool.blockSize;
auto poolDtype = pool.containsBlockScales ? kScaleDtypeNVFP4 : mDataType;
#ifdef ENABLE_FP4
auto const poolIsFP4 = poolDtype == nvinfer1::DataType::kFP4;
#else
auto const poolIsFP4 = false;
#endif
if (poolIsFP4)
{
poolDtype = nvinfer1::DataType::kINT8;
}
if (pool.containsIndexerKCache)
{
poolDtype = nvinfer1::DataType::kUINT8;
}
nvinfer1::Dims cacheShape;
cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});
TLLM_LOG_DEBUG("[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(),
mNumPrimaryBlocks, pool.numLayers, pool.numKvHeads);
if (useUvm)
pool.primaryPtr = BufferManager::managed(cacheShape, poolDtype);
else
pool.primaryPtr = mBufferManager.gpuSync(cacheShape, poolDtype);
if (mNumSecondaryBlocks > 0)
{
nvinfer1::Dims const cacheShapeOffload
= ITensor::makeShape({mNumSecondaryBlocks, pool.numLayers, mKVFactor, blockSize});
TLLM_LOG_DEBUG("[%s] Allocating secondary pool with %d blocks for %d layers with %d kv heads",
mLogPrefix.c_str(), mNumSecondaryBlocks, pool.numLayers, pool.numKvHeads);
pool.secondaryPtr = BufferManager::pinned(cacheShapeOffload, poolDtype);
}
}
}
void BlockManager::releasePools()
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.releasePools();
}
}
void WindowBlockManager::releasePools()
{
for (auto& pool : mPools)
{
if (pool.primaryPtr)
{
pool.primaryPtr->release();
}
if (pool.secondaryPtr)
{
pool.secondaryPtr->release();
}
}
mBufferManager.getStream().synchronize();
mBufferManager.memoryPoolTrimTo(0);
}
void BlockManager::startScheduling()
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.startScheduling();
}
}
void WindowBlockManager::startScheduling()
{
mSchedulingNumFreeBlocks = mEvictionPolicy->getNumFreeBlocks(kPrimaryLevel);
for (auto& [requestId, slotAllocatedBlocks] : mAllocatedBlocksPerSeq)
{
for (auto& allocatedBlock : slotAllocatedBlocks)
{
allocatedBlock->startScheduling();
}
}
}
void WindowBlockManager::freeLeafBlock(BlockPtr const& block)
{
// The eviction policy needs blocks to still be linked to their old parents when they're reclaimed.
// This is so it can check if the parent should be queued for eviction.
block->freeLeafBlock();
}
void WindowBlockManager::freeChildren(BlockPtr const& block)
{
// Free all descendants of block
for (auto const& p : block->getNextBlocks())
{
auto childBlock = p.second;
freeChildren(childBlock);
}
// Free block
if (mEventManager && blockInRadixTree(block))
{
mEventManager->enqueueRemovedEvent(block, mWindowSize);
}
freeLeafBlock(block);
}
BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority,
std::optional<std::chrono::milliseconds> durationMs, executor::KvCacheTransferMode mode,
std::string const& directory)
{
// eviction policy get free primary block
auto [block, canOffload] = mEvictionPolicy->getFreeBlock(kPrimaryLevel);
if (block->getUniqueTokens().empty())
{
++mAllocNewBlocks;
}
++mAllocTotalBlocks;
// Offloading is an option only when these conditions are met:
// 1. Block contains state (evidenced by presence of tokens)
// 2. Eviction policy indicated block can be offloaded
// 3. At least one free block in secondary memory
// 4. Onboarding is enabled (allowing block to be brought back into primary)
if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0
&& mOnboardBlocks)
{
// Offload block in primary memory before repurposing
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block)
block->swapMemoryPoolBlockOffset(offloadBlock);
if (mEventManager && blockInRadixTree(block))
{
mEventManager->enqueueUpdatedEvent(
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel),
mWindowSize);
}
// Update the block as a secondary block (maintaining its priority)
mEvictionPolicy->claimBlock(block);
// Release the block into secondary block queue
mEvictionPolicy->releaseBlock(block);
// We have the offloaded block as the block to use now.
block = offloadBlock;
}
// Removes children of the block from the search tree
freeChildren(block);
// Claim the block in primary block queue
mEvictionPolicy->claimBlock(block, priority, durationMs);
// Invalidate store-for-reuse for the sequence that previously owned this block, if needed
if (mBlockToSequence.count(block->getBlockId()) > 0)
{
auto const& originalOwnerSequenceId = mBlockToSequence[block->getBlockId()];
if (mIsValidStoreForReuseSequence.count(originalOwnerSequenceId) > 0
&& sequence.getRequestId() != originalOwnerSequenceId)
{
TLLM_LOG_DEBUG("%s::getFreeBlock - Block %d was originally held but released from sequence %d",
mLogPrefix.c_str(), block->getBlockId(), originalOwnerSequenceId);
if (mIsValidStoreForReuseSequence[originalOwnerSequenceId])
{
TLLM_LOG_DEBUG("%s::getFreeBlock - Invalidate store block for reuse for sequence %d",
mLogPrefix.c_str(), originalOwnerSequenceId);
}
else
{
TLLM_LOG_DEBUG("%s::getFreeBlock - Store block for reuse for sequence %d is already invalid",
mLogPrefix.c_str(), originalOwnerSequenceId);
}
mIsValidStoreForReuseSequence[originalOwnerSequenceId] = false;
}
}
// Record which sequence is using the block
mBlockToSequence[block->getBlockId()] = sequence.getRequestId();
TLLM_LOG_DEBUG("%s::getFreeBlock - Block %d is now acquired by sequence %d", mLogPrefix.c_str(),
block->getBlockId(), sequence.getRequestId());
return block;
}
void WindowBlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape,
SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId) const
{
auto constexpr kIdx = 0;
auto constexpr vIdx = 1;
auto const& block = mAllBlocksById[blockId];
for (SizeType32 poolIdx = 0; poolIdx < static_cast<SizeType32>(mPools.size()); poolIdx++)
{
auto const& pool = mPools.at(poolIdx);
for (auto const xIdx : {kIdx, vIdx})
{
auto constexpr layerIdx = 0;
auto const offsetIndex = tensorrt_llm::common::flat_index(offsetsShape.d, poolIdx, beamIdx, xIdx, blockIdx);
auto const fieldIdx = mCacheType == CacheType::kSELFKONLY ? 0 : xIdx;
auto const blockIndex = tk::KVCacheIndex{
common::flat_index3(block->getMemoryPoolBlockIndex(), layerIdx, fieldIdx, pool.numLayers, mKVFactor)};
offsetsPtr[offsetIndex] = blockIndex;
}
}
}
void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
SizeType32 blockIdx, KVCacheBlock::IdType blockId, SizeType32 windowSize) const
{
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
void BlockManager::onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock, SizeType32 windowSize,
executor::KvCacheTransferMode mode, std::string const& directory)
{
mWindowBlockManagers.at(windowSize).onboardBlock(sequence, offloadBlock, mode, directory);
}
void WindowBlockManager::onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock,
executor::KvCacheTransferMode mode, std::string const& directory)
{
if (mOnboardBlocks && !offloadBlock->isPrimary())
{
auto block = getFreeBlock(
sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, mode, directory);
mTransferManager->onboard(offloadBlock, block, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block and vice versa)
offloadBlock->swapMemoryPoolBlockOffset(block);
if (mEventManager)
{
mEventManager->enqueueUpdatedEvent(
tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel),
mWindowSize);
}
mEvictionPolicy->releaseBlock(block); // append block to offload queue
// offloadBlock is now in primary memory pool
}
}
void BlockManager::offloadBlock(
BlockPtr const& block, SizeType32 windowSize, executor::KvCacheTransferMode mode, std::string const& directory)
{
mWindowBlockManagers.at(windowSize).offloadBlock(block, mode, directory);
}
void WindowBlockManager::offloadBlock(
BlockPtr const& block, executor::KvCacheTransferMode mode, std::string const& directory)
{
// The current default behavior is to offload the out-of-window block
// to the secondary block pool to allow more free primary blocks for reuse.
// However, this does not take into account whether the offloaded block
// is actually useful and may just create more traffic instead.
// Ideally, the decision to offload a block would be delegated to the
// eviction policy.
if (mOnboardBlocks && block->isPrimary())
{
// Offload block in primary memory before repurposing
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
// If we're swapping a block to secondary memory, maintain the prior priority values.
mEvictionPolicy->claimBlock(offloadBlock);
mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block)
block->swapMemoryPoolBlockOffset(offloadBlock);
if (mEventManager && blockInRadixTree(block))
{
mEventManager->enqueueUpdatedEvent(
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel),
mWindowSize);
}
mEvictionPolicy->releaseBlock(offloadBlock); // append offloadBlock to mFreePrimaryBlocks queue
// block is now in secondary memory
}
}
[[nodiscard]] std::optional<BlockKey> BlockManager::findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const
{
TLLM_CHECK_WITH_INFO(
!isVariableWindow(), "The optimization of delaying requests won't work for variable window attention");
auto const& onlyManager = mWindowBlockManagers.cbegin()->second;
return onlyManager.findNewContextBlock(uniqueTokens, llmRequest);
}
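// Walk the reuse tree along the request's block keys and return the accumulated key of the first context block
// that is not yet cached; returns std::nullopt if every context block is already cached.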
std::optional<BlockKey> WindowBlockManager::findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const
{
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, uniqueTokens.size(), mTokensPerBlock, false);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
BlockKey ret;
ret.loraTaskId = llmRequest.getLoraTaskId();
auto searchRoot = mCachedBlocksRoot;
for (auto const& blockKey : blockKeys)
{
ret.uniqueTokens.insert(ret.uniqueTokens.end(), blockKey.uniqueTokens.begin(), blockKey.uniqueTokens.end());
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr
? searchRoot->findMatchingBlock(blockKey, false, false)
: std::make_tuple(false, 0, nullptr);
if (matchingBlock == nullptr)
{
return ret;
}
searchRoot = std::move(matchingBlock);
}
return std::nullopt;
}
bool WindowBlockManager::blockInRadixTree(BlockPtr const& block)
{
return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr;
}
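// Chop the given key's tokens into block-sized keys and walk the reuse tree with partial matching enabled;
// returns the deepest matching block, or nullptr if any block along the path is missing.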
std::shared_ptr<KVCacheBlock> WindowBlockManager::findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey)
{
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true);
std::vector<BlockKey> blockKeys;
for (auto const& blockedUniqueTokensList : blockedUniqueTokens)
{
blockKeys.push_back(blockKey);
blockKeys.back().uniqueTokens = blockedUniqueTokensList;
}
auto searchRoot = mCachedBlocksRoot;
for (auto const& blockKey : blockKeys)
{
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr
? searchRoot->findMatchingBlock(blockKey, true, true)
: std::make_tuple(false, 0, nullptr);
if (matchingBlock == nullptr)
{
return nullptr;
}
searchRoot = std::move(matchingBlock);
}
return searchRoot;
}
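// Walk the reuse tree along blockKeys and attach matched (fully or partially reusable) blocks to the sequence,
// allocating fresh blocks once matching fails; blocks that cannot be shared across beams are always freshly
// allocated. Returns the number of tokens covered by reused blocks.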
SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions,
executor::KvCacheTransferMode mode, std::string const& directory)
{
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
SizeType32 numMatchedTokens{0};
auto searchRoot = mCachedBlocksRoot;
// The last block cannot be shared between beams because it will be written to.
// Make sure a unique block is allocated per beam.
auto const beamWidth = sequence.getBeamWidth();
SizeType32 numSharedContextBlocks = beamWidth > 1 ? numContextBlocks - 1 : numContextBlocks;
auto blockItr = blockKeys.begin();
for (int bi = 0; bi < numSharedContextBlocks; ++bi)
{
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end()
? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
: std::make_tuple(false, 0, nullptr);
if (matchingBlock != nullptr)
{
KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId();
numMatchedTokens += numMatched > 0 ? numMatched : blockItr->uniqueTokens.size();
if (perBlockRetentions[bi].retentionPriority.has_value()
&& matchingBlock->getPriority() != perBlockRetentions[bi].retentionPriority && mEventManager)
{
mEventManager->enqueueUpdatedEvent(
tle::KVCacheUpdatedData(matchingBlock->getHash())
.priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority),
mWindowSize);
}
if (partialMatch)
{
if (matchingBlock->hasRefs() || !matchingBlock->isLeaf())
{
// Somebody else is using block or it is not a leaf, copy reusable tokens
auto newBlock = getFreeBlock(
sequence, matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory);
mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory);
// TODO: (optional) Send out event
matchingBlock = newBlock;
if (blockItr != blockKeys.end())
{
matchingBlock->setBlockKey(
*blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
}
matchingBlock->setHash();
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Copied partially filled block %d", mLogPrefix.c_str(),
matchingBlockId);
}
else
{
// Leaf block that nobody is using. Make block private and reuse
freeLeafBlock(matchingBlock);
mEvictionPolicy->claimBlock(
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(),
matchingBlockId);
}
searchRoot = nullptr; // no matching needed for following blocks
}
else
{
// Recover block and reuse
mEvictionPolicy->claimBlock(
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId);
searchRoot = matchingBlock;
}
onboardBlock(sequence, matchingBlock, mode, directory);
addBlockToAllBeams(matchingBlock, sequence);
// TODO: only add once for reused blocks
++mReusedBlocks;
if (!reusedBlockIds.count(matchingBlockId))
{
reusedBlockIds.insert(matchingBlockId);
++mReusedUniqueBlocks;
}
++blockItr;
}
else
{
// If we haven't set a priority, set it to the default priority level (low)
auto freeBlock = getFreeBlock(sequence,
perBlockRetentions[bi].retentionPriority.value_or(
executor::KvCacheRetentionConfig::kDefaultRetentionPriority),
perBlockRetentions[bi].durationMs, mode, directory);
addBlockToAllBeams(freeBlock, sequence);
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - No match, allocated new block %d for sequence %lu",
mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId());
searchRoot = nullptr; // no matching needed for following blocks
if (blockItr != blockKeys.end())
{
freeBlock->setBlockKey(
*blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
++blockItr;
}
freeBlock->setHash();
++mMissedBlocks;
}
}
// Allocate new blocks that cannot be shared by multiple beams.
for (int bi = numSharedContextBlocks; bi < numContextBlocks; ++bi)
{
// TODO: Still look for match. Clone matching block or allocate fresh ones.
// This work is described in JIRA task https://jirasw.nvidia.com/browse/TRTLLM-2069.
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
// If we haven't set a priority, set it to the default priority level (low)
auto freeBlock = getFreeBlock(sequence,
perBlockRetentions[bi].retentionPriority.value_or(
executor::KvCacheRetentionConfig::kDefaultRetentionPriority),
perBlockRetentions[bi].durationMs, mode, directory);
addBlockToBeam(freeBlock, sequence, beamIdx);
if (blockItr != blockKeys.end())
{
freeBlock->setBlockKey(
*blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
++blockItr;
}
freeBlock->setHash();
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d",
mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi);
}
++mMissedBlocks;
if (blockItr != blockKeys.end())
{
++blockItr;
}
}
return numMatchedTokens;
}
void BlockManager::refreshBlocks()
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.refreshBlocks();
}
}
void WindowBlockManager::refreshBlocks()
{
mEvictionPolicy->refresh();
mTransferManager->syncTransfers();
}
// There are two versions of BlockManager::addSequence function.
// This is called when block reuse is enabled.
void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
LlmRequest& llmRequest, SizeType32 windowSize)
{
mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
}
// There are two versions of WindowBlockManager::addSequence function.
// This is called when block reuse is enabled.
void WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
{
auto const requestId = sequence.getRequestId();
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
TLLM_CHECK(emplaceDone);
auto constexpr beamIdx = 0;
auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY)
? llmRequest.getUniqueTokens(beamIdx)
: *(llmRequest.getEncoderUniqueTokens().value());
// Ignore last token because it can't be recovered
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, inputLength - 1, mTokensPerBlock, true);
// Add an empty block key if the last token starts a new block on its own
if (inputLength % mTokensPerBlock == 1)
{
blockedUniqueTokens.emplace_back();
}
auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
auto config = llmRequest.getKvCacheRetentionConfig();
auto perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig())
.getPerBlockRetentionPriorityDuration(getTokensPerBlock(), inputLength);
auto mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode();
auto directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory();
if (mode != executor::KvCacheTransferMode::DRAM && directory.empty())
{
TLLM_LOG_WARNING(
"Transfer mode %d specified without directory, falling back to DRAM mode", static_cast<int>(mode));
mode = executor::KvCacheTransferMode::DRAM;
}
TLLM_CHECK(perBlockRetentions.size() == (size_t) numContextBlocks);
auto const prepopulatedPromptLen
= loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions, mode, directory);
mReusedTokens += static_cast<double>(prepopulatedPromptLen);
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
SizeType32 numConnectorMatchedTokens = 0;
// If we're using a KV cache connector, check if any additional blocks can be loaded.
if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
{
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
}
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
}
void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
{
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
mWindowBlockManagers.at(windowSize).adjustBlocksIfNeeded(sequence);
}
}
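// For SWA, detach front blocks that have gone fully out-of-window, and allocate a new block when the last token
// lands on a block boundary.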
void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
{
auto const minTokensForBlockDetach = mWindowSize + mTokensPerBlock;
while (
sequence.getNumTokens() - sequence.getNumFrontBlocksRemoved() * getTokensPerBlock() >= minTokensForBlockDetach)
{
// Detaching block for SWA is non-trivial due to the radix tree structure.
// For now, when reuse is enabled, we do not detach blocks for SWA.
TLLM_CHECK_WITH_INFO(mIsSWA, "A block can only go out-of-window in SWA");
detachFrontBlock(sequence);
}
if ((sequence.getNumTokens() - 1) % getTokensPerBlock() == 0)
{
// Allocating a new block when the last token is a block boundary
allocateBlock(sequence, /*shareAmongBeams=*/sequence.getBeamWidth() == 1);
updateLastCacheBlockOffsets(sequence);
}
}
// There are two versions of BlockManager::addSequence function.
// This is called when block reuse is disabled.
void BlockManager::addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock)
{
mWindowBlockManagers.at(windowSize).addSequence(sequence, numContextBlocks, isShareLastContextBlock);
}
// There are two versions of WindowBlockManager::addSequence function.
// This is called when block reuse is disabled.
void WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
{
if (mKvCacheConnectorManager)
{
TLLM_LOG_WARNING(
"KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
"ignored.");
}
auto const requestId = sequence.getRequestId();
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
TLLM_CHECK(emplaceDone);
TLLM_CHECK_WITH_INFO(numContextBlocks > 0, "numContextBlocks must be greater than 0");
for (SizeType32 bi = 0; bi < numContextBlocks - 1; ++bi)
{
allocateBlock(sequence, /*shareAmongBeams=*/true);
}
allocateBlock(sequence, /*shareAmongBeams=*/isShareLastContextBlock);
}
void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx)
{
auto const requestId = sequence.getRequestId();
block->incRefCount();
if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0)
{
block->setPrevBlockInSeq(nullptr);
}
else
{
block->setPrevBlockInSeq(mAllBlocksById.at(sequence.getCacheBlockIds(mWindowSize)[beamIdx].back()));
}
sequence.addCacheBlock(mWindowSize, beamIdx, block->getBlockId());
mAllocatedBlocksPerSeq.at(requestId).push_back(block);
}
void WindowBlockManager::addBlockToAllBeams(BlockPtr& block, GenerationRequest& sequence)
{
auto const beamWidth = sequence.getBeamWidth();
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
addBlockToBeam(block, sequence, beamIdx);
}
}
void BlockManager::allocateBlock(GenerationRequest& sequence, SizeType32 windowSize)
{
mWindowBlockManagers.at(windowSize).allocateBlock(sequence, false);
}
void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAmongBeams)
{
auto const beamWidth = sequence.getBeamWidth();
auto const requiredBlocks = shareAmongBeams ? 1 : beamWidth;
TLLM_CHECK_WITH_INFO(hasFreeBlocks(requiredBlocks), "Can't allocate new blocks. No free blocks left.");
if (shareAmongBeams)
{
// add same block to all beams
auto block = getFreeBlock(sequence, sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(),
sequence.getTransferMode(), sequence.getDirectory());
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
addBlockToBeam(block, sequence, beamIdx);
}
}
else
{
// add different block to each beam
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = getFreeBlock(sequence, sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(),
sequence.getTransferMode(), sequence.getDirectory());
addBlockToBeam(block, sequence, beamIdx);
}
}
}
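// Insert the given blocks into the reuse (radix) tree keyed by blockKeys, skipping prefixes that are already
// cached; optionally pins the visited blocks by taking a reference. Returns the number of blocks newly stored
// for reuse and the id of the last block on the stored chain.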
std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
{
SizeType32 numBlocksStoredForReuse = 0;
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
TLLM_LOG_DEBUG(
"%s::storeBlocks - %zu blockKeys, %zu blockIds", mLogPrefix.c_str(), blockKeys.size(), blockIds.size());
auto searchRoot = mCachedBlocksRoot;
bool needMatch = true;
auto numBlocks = blockKeys.size();
std::vector<BlockPtr> storedBlocks;
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{
auto const bid = blockIds[blockCnt];
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
auto& block = mAllBlocksById[bid];
auto const& blockKey = blockKeys[blockCnt];
auto [partialMatch, numMatched, matchedBlock]
= needMatch ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr);
if (matchedBlock != nullptr)
{
// Found match
TLLM_LOG_DEBUG(
"%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId());
searchRoot = matchedBlock;
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
// block can be freed and inserted at mFreePrimaryBlocks.begin()
}
else
{
// No match
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(),
block->getBlockId());
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
needMatch = false; // no matching needed for following blocks
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
block->setPrevBlock(searchRoot);
block->setPrevBlockInSeq(searchRoot);
searchRoot->addNextBlock(blockKey, block);
// Sanity check. The list of stored blocks should be connected.
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
storedBlocks.push_back(block);
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
auto oldHash = block->getHash();
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
if (oldHash != newHash)
{
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
block->setHash(newHash);
}
searchRoot = block;
numBlocksStoredForReuse++;
}
if (pinBlocks)
{
searchRoot->incRefCount();
}
lastStoredId = searchRoot->getBlockId();
}
if (mEventManager)
{
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
}
return {numBlocksStoredForReuse, lastStoredId};
}
void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
{
mWindowBlockManagers.at(windowSize).replaceSharedBlock(sequence, blockIdx);
}
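// If the block at blockIdx is shared between beams, replace it with a private copy per beam (keeping the
// original block key and hash chain) so each beam can write its own KV values.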
void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx)
{
auto const requestId = sequence.getRequestId();
auto const beamWidth = sequence.getBeamWidth();
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
if (!allocatedBlocks.at((blockIdx + 1) * beamWidth - 1)->isShared())
{
return;
}
BlockKey blockKey = allocatedBlocks.at(blockIdx * beamWidth)->getBlockKey();
bool isFull = allocatedBlocks.at(blockIdx * beamWidth)->isFull();
// Free shared block
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = allocatedBlocks.at(blockIdx * beamWidth + beamIdx);
block->decRefCount();
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
}
// Allocate new blocks
TLLM_CHECK_WITH_INFO(hasFreeBlocks(beamWidth), "Can't allocate new blocks. No free blocks left.");
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = getFreeBlock(sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt,
sequence.getTransferMode(), sequence.getDirectory());
block->incRefCount();
if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0)
{
block->setPrevBlockInSeq(nullptr);
}
else
{
block->setPrevBlockInSeq(mAllBlocksById.at(sequence.getCacheBlockIds(mWindowSize)[beamIdx].back()));
}
block->setBlockKey(blockKey, isFull);
block->setHash();
sequence.changeCacheBlock(mWindowSize, beamIdx, blockIdx, block->getBlockId());
allocatedBlocks.at(blockIdx * beamWidth + beamIdx) = block;
}
}
void BlockManager::releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize)
{
mWindowBlockManagers.at(windowSize).releaseLastBlock(sequence);
}
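//! \brief Release the last block of a sequence: drop its reference, hand it back to the eviction
//!        policy when no references remain, and remove it from the sequence's block ids.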
void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
{
auto const requestId = sequence.getRequestId();
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
auto it = allocatedBlocks.rbegin();
auto& block = *it;
// Decrease ref count
block->decRefCount();
// If ref count is zero, move block to free blocks
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block, true);
}
// Remove block from allocated blocks
allocatedBlocks.pop_back();
// Remove stored block ids in sequence
sequence.removeLastBlock(mWindowSize);
}
[[nodiscard]] SizeType32 WindowBlockManager::getNumFreeBlocks() const noexcept
{
return mEvictionPolicy->getNumFreeBlocks(kPrimaryLevel);
}
std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::chrono::milliseconds> timeout) const
{
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
}
std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
for (auto& [_, manager] : mWindowBlockManagers)
{
lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
}
return lastStoredId;
}
std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
    // Released blocks will be stored for reuse when reuse is enabled.
    // Reuse is implied to be enabled when llmRequest is provided.
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
for (auto& [_, manager] : mWindowBlockManagers)
{
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1)
{
lastStoredId = manager.releaseBlocks(sequence, std::nullopt);
}
else
{
lastStoredId = manager.releaseBlocks(sequence, llmRequest);
}
}
return lastStoredId;
}
void BlockManager::pinBlocks(GenerationRequest& sequence)
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.pinBlocks(sequence);
}
}
void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
{
// Use the first window size
if (mWindowBlockManagers.empty())
{
return;
}
auto& firstManager = mWindowBlockManagers.begin()->second;
firstManager.unpinBlocksById(blockId);
}
void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
{
auto const requestId = sequence.getRequestId();
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
for (auto& block : allocatedBlocks)
{
block->incRefCount();
}
}
void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
{
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
{
return;
}
auto block = mAllBlocksById[blockId];
while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
block->decRefCount();
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
block = std::move(block->getPrevBlock());
}
}
void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
{
for (auto& [_, manager] : mWindowBlockManagers)
{
if (manager.isSWA())
{
            // SWA cannot store new blocks on the fly because a stored block
            // may go out of window (OOW) and be reused by another sequence.
continue;
}
manager.storeNewBlock(sequence, llmRequest);
}
}
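//! \brief Store the most recently completed block of a sequence for reuse during generation.
//! \details Returns early when the sequence has not reached a block boundary or when the last block is
//!          already present in the reuse tree; falls back to storing all blocks when the preceding
//!          blocks are not yet linked into the tree.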
void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
{
auto constexpr beamIdx = 0;
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize);
if (uniqueTokens.size() == 0)
{
return;
}
// TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
// have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
// the last token's state is not filled yet.
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
if (usableSize % mTokensPerBlock != 0)
{
return;
}
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
if (blockKeys.size() < 2 || cacheBlockIds[beamIdx].size() < blockKeys.size())
{
// store all blocks
TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str());
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
return;
}
auto lastBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 1]);
auto prevBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 2]);
// If the previous block is not in the radix tree, we need to store all blocks
if (prevBlock->getPrevBlock() == nullptr)
{
TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str());
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
return;
}
if (lastBlock->getPrevBlock() != nullptr)
{
        // The last block is already in the radix tree, so there is nothing new to store
TLLM_LOG_DEBUG("%s::storeNewBlock - no need to store", mLogPrefix.c_str());
return;
}
TLLM_LOG_DEBUG("%s::storeNewBlock - store the last block", mLogPrefix.c_str());
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}
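//! \brief Store all filled blocks of a sequence for reuse without releasing them.
//! \param pinBlocks if true, pin the stored blocks so they cannot be evicted
//! \return id of the last stored block, if any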
std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
auto constexpr beamIdx = 0;
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize);
// TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
// have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
// the last token's state is not filled yet.
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
}
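//! \brief Release all blocks of a sequence, optionally storing them for reuse first.
//! \details When llmRequest is provided and the sequence is valid for reuse, the filled blocks are
//!          stored in the reuse tree before the reference counts are dropped and unreferenced blocks
//!          are returned to the eviction policy.
//! \return id of the last block stored for reuse, if any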
std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
{
auto const requestId = sequence.getRequestId();
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
auto node = mAllocatedBlocksPerSeq.extract(requestId);
TLLM_CHECK(node);
auto& allocatedBlocks = node.mapped();
if (llmRequest.has_value())
{
// If llmRequest is provided, block store for reuse is enabled.
if (!isSequenceValidForStoreForReuse(requestId))
{
TLLM_LOG_DEBUG(
"%s::releaseBlocks - sequence %lu does not have all blocks valid, block is not saved for reuse",
mLogPrefix.c_str(), sequence.getRequestId());
}
else
{
if (mIsSWA)
{
TLLM_LOG_DEBUG("%s::releaseBlocks - sequence %lu is valid for store for reuse", mLogPrefix.c_str(),
sequence.getRequestId());
}
auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0);
// Only (length - 1) tokens of the sequence have their kv-state
// recorded in kv-cache. We assume the last token's state is not filled yet.
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
std::vector<KVCacheBlock::IdType> cacheBlockIds(allocatedBlocks.size());
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
[](BlockPtr const& block) { return block->getBlockId(); });
            auto const [numBlocksStoredForReuse, storedId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
            // Assign through the outer lastStoredId; declaring it in the structured binding would shadow the
            // variable returned below and silently drop the stored id.
            lastStoredId = storedId;
            TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
                sequence.getRequestId(), numBlocksStoredForReuse);
}
}
for (auto it = allocatedBlocks.rbegin(); it != allocatedBlocks.rend() - sequence.getNumFrontBlocksRemoved(); ++it)
{
auto& block = *it;
// Decrease ref count
if (block->hasRefs())
{
// An out-of-window block may not have any ref count.
block->decRefCount();
}
// If ref count is zero, move block to free blocks
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
}
// Remove stored block ids in sequence
sequence.clearCacheBlocks(mWindowSize);
return lastStoredId;
}
void BlockManager::schedulingReleaseBlocks(RequestIdType requestId)
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.schedulingReleaseBlocks(requestId);
}
}
void WindowBlockManager::schedulingReleaseBlocks(RequestIdType requestId)
{
for (auto& block : mAllocatedBlocksPerSeq.at(requestId))
{
// Decrease ref count
block->decSchedulingRefCount();
// If ref count is zero, move block to free blocks
if (!block->hasSchedulingRefs())
{
mSchedulingNumFreeBlocks++;
}
}
}
KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse,
bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse, bool enableIndexerKCache,
SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim)
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse,
nullptr, enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim)
{
}
KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager, bool enableIndexerKCache,
SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim)
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
copyOnPartialReuse, kvCacheConnectorManager, enableIndexerKCache, indexerKCacheQuantBlockSize,
indexerKCacheIndexHeadDim)
{
}
KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager, bool enableIndexerKCache,
SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim)
: mMaxBeamWidth(maxBeamWidth)
, mDataType(dtype)
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
, mTokensPerBlock(tokensPerBlock)
, mSinkBubbleLength(BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock))
, mSinkBlockTokenLength(mSinkBubbleLength + sinkTokenLength)
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager), std::nullopt, enableIndexerKCache,
indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim)
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
TLLM_CHECK_WITH_INFO(mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0,
"[kv cache manager] streamLLM is not supported at the moment");
TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
!= maxAttentionWindowVec.end());
// The sink tokens are stored in blocks separate from other tokens.
// If the last block of sink tokens is only partially filled,
// we fill that block with a "bubble" to reach the number of tokens per block.
TLLM_CHECK(mSinkBlockTokenLength % tokensPerBlock == 0);
TLLM_LOG_DEBUG("KV cache block reuse is %s", mEnableBlockReuse ? "enabled" : "disabled");
mSequences.reserve(maxNumSequences);
}
KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager, bool enableIndexerKCache,
SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim)
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager),
enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim)
{
}
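//! \brief Allocate the primary and secondary memory pools and build the host-side tables consumed by
//!        the attention kernels: per-pool pointer pairs (primary/secondary), block-scale pool pointers,
//!        the indexer K-cache pool pointer, and the layer-to-pool mapping.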
void KVCacheManager::allocatePools(bool useUvm)
{
mBlockManager.allocatePools(useUvm);
auto const numPools = mBlockManager.getNumPools();
uint64_t cacheSizeBytes = 0;
for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++)
{
auto const cacheShape = mBlockManager.getPrimaryPool(poolIdx)->getShape();
auto const cacheVolume = ITensor::volume(cacheShape);
#ifdef ENABLE_FP4
auto const isFp4 = mDataType == nvinfer1::DataType::kFP4;
#else
auto const isFp4 = false;
#endif
if (!isFp4)
{
cacheSizeBytes += cacheVolume * BufferDataType(mDataType).getSize();
}
else
{
cacheSizeBytes += (cacheVolume * 4) / 8;
}
}
// Save the total number of bytes allocated for the KV-cache for KvCacheStats
mAllocatedBytes = cacheSizeBytes;
if (tc::Logger::getLogger()->getLevel() <= tc::Logger::INFO)
{
TLLM_LOG_INFO("Number of tokens per block: %d.", mBlockManager.getTokensPerBlock());
auto const maxNumTokens = mBlockManager.getNumPrimaryBlocks() * mBlockManager.getTokensPerBlock();
TLLM_LOG_INFO("[MemUsageChange] Allocated %0.2f GiB for max tokens in paged KV cache (%d).",
cacheSizeBytes / static_cast<double>(1 << 30), maxNumTokens);
}
auto const numKVPools
= mBlockManager.getNumPools(/*include_block_scalar_pools=*/false, /*include_indexer_k_cache_pools=*/false);
auto const numBlockScalePools
= mBlockManager.getNumPools(/*includeBlockScalePools=*/true, /*includeIndexerKCachePools=*/false) - numKVPools;
// Code in the attention kernels is cleaner if we can access the KV values and block scales separately.
mBlockPoolPointers = BufferManager::cpu(ITensor::makeShape({numKVPools, 2}), TRTDataType<void*>::value);
mBlockScalePoolPointers
= BufferManager::cpu(ITensor::makeShape({numBlockScalePools, 2}), TRTDataType<void*>::value);
auto poolPtrsRange = BufferRange<void*>(*mBlockPoolPointers);
auto blockScalePtrsRange = BufferRange<void*>(*mBlockScalePoolPointers);
SizeType32 kvPoolIdx = 0;
SizeType32 blockScalePoolIdx = 0;
for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++)
{
auto const& pool = mBlockManager.getPool(poolIdx);
auto& outIdx = pool.containsBlockScales ? blockScalePoolIdx : kvPoolIdx;
auto& outRange = pool.containsBlockScales ? blockScalePtrsRange : poolPtrsRange;
if (pool.containsIndexerKCache)
{
mIndexerKCachePoolPointers = pool.primaryPtr;
}
else
{
outRange[outIdx * 2] = pool.primaryPtr->data();
outRange[outIdx * 2 + 1] = pool.secondaryPtr ? pool.secondaryPtr->data() : nullptr;
outIdx++;
}
}
auto const numLayers = mBlockManager.getNumLayers();
mLayerToPoolMapping = BufferManager::cpu(ITensor::makeShape({numLayers, 2}), TRTDataType<SizeType32>::value);
auto poolMappingRange = BufferRange<SizeType32>(*mLayerToPoolMapping);
for (SizeType32 layerIdx = 0; layerIdx < numLayers; layerIdx++)
{
auto const indexOfPool = mBlockManager.getLayerPoolIdx(layerIdx);
auto const layerIdxInCachePool = mBlockManager.getPoolLayerIdx(layerIdx);
poolMappingRange[layerIdx * 2] = indexOfPool;
poolMappingRange[layerIdx * 2 + 1] = layerIdxInCachePool;
}
}
void KVCacheManager::releasePools()
{
mBlockManager.releasePools();
}
void KVCacheManager::startScheduling()
{
mBlockManager.startScheduling();
}
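//! \brief Compute the number of new KV cache blocks needed to advance a request by one step.
//! \details For the first context chunk this covers the prompt (plus draft tokens), capped by the
//!          attention window; during generation it is the number of block boundaries crossed by the
//!          next one or two steps. For example, assuming tokensPerBlock = 64, a 100-token prompt,
//!          no draft tokens, no sink bubble, a window larger than the prompt and beamWidth = 2, the
//!          context phase needs 100 / 64 = 1 shared block plus ceilDiv(100 % 64, 64) * 2 = 2 unshared
//!          blocks, i.e. 3 blocks in total.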
SizeType32 KVCacheManager::getNeededBlocksOneStep(
LlmRequest const& req, bool twoStepsLookAhead, SizeType32 windowSize) const
{
if ((req.isContextInitState() && req.isFirstContextChunk()) || req.isDisaggGenerationInitState())
{
auto const chunkSize = req.mMaxNewTokens;
auto const maxDraftTokensToAdd = req.getNumDraftTokens();
auto const promptCacheLen
= std::min((isCrossKv() ? req.getEncoderOutputLen() : req.mPromptLen) + maxDraftTokensToAdd,
windowSize + chunkSize)
+ mSinkBubbleLength;
auto const numSharedBlocks = promptCacheLen / getTokensPerBlock();
auto const numUnSharedTokens = promptCacheLen % getTokensPerBlock();
auto const numUnSharedBlocks
= tc::ceilDiv(numUnSharedTokens, getTokensPerBlock()) * req.mSamplingConfig.beamWidth;
auto const numRequiredBlocks = numSharedBlocks + numUnSharedBlocks;
return numRequiredBlocks;
}
if (req.isGenerationInProgressState())
{
if (isCrossKv())
{
return 0;
}
auto const numCurrTokens = getSequence(req.mRequestId).getNumTokens();
auto const generatedTokens = numCurrTokens - req.getPromptLen();
auto const maxTokensToAddToKVCache = req.mMaxNewTokens - generatedTokens;
auto const tokensPerStep = req.getNumDraftTokens() + 1;
auto const maxTokensToAdd = std::min((twoStepsLookAhead ? 2 : 1) * tokensPerStep, maxTokensToAddToKVCache);
auto const numNextTokens = numCurrTokens + maxTokensToAdd;
if (numNextTokens > mBlockManager.getWindowSizeMetadata(windowSize).maxTokenNum)
{
return 0;
}
auto const numCurrBlocks = tc::ceilDiv(numCurrTokens, getTokensPerBlock());
auto const numNextBlocks = tc::ceilDiv(numNextTokens, getTokensPerBlock());
auto const numRequiredBlocks = (numNextBlocks - numCurrBlocks) * req.mSamplingConfig.beamWidth;
return numRequiredBlocks;
}
return 0;
}
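//! \brief Estimate how many more blocks a request needs to run to completion for the given window size.
//! \details Counts the context and generation blocks still to be allocated (generation blocks per beam),
//!          plus one extra block per beam when sliding window attention will cross both a block boundary
//!          and a window boundary before the maximum sequence length is reached.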
SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
{
if (isCrossKv())
{
if (req.isContextInitState() && req.getContextCurrentPosition() == 0)
{
return tc::ceilDiv(req.getEncoderOutputLen(), getTokensPerBlock());
}
return 0; // cross KV cache doesn't grow after the initial context phase
}
auto const temporaryAttentionWindow = mBlockManager.getWindowSizeMetadata(windowSize).temporaryAttentionWindow;
SizeType32 const numContextBlocks
= (std::min(req.mPromptLen, windowSize + temporaryAttentionWindow) + mSinkBubbleLength) / getTokensPerBlock();
SizeType32 const numTotalBlocksPerBeam = tc::ceilDiv(
std::min(req.mPromptLen + req.mMaxNewTokens, windowSize + temporaryAttentionWindow) + mSinkBubbleLength,
getTokensPerBlock());
SizeType32 const numGenBlocksPerBeam = numTotalBlocksPerBeam - numContextBlocks;
SizeType32 numAllocBlocksPerBeam = 0;
{
std::scoped_lock lck(mSequencesMtx);
auto const seqIt = mSequences.find(req.mRequestId);
if (seqIt != mSequences.end())
{
auto const& seq = seqIt->second;
numAllocBlocksPerBeam = seq.getCacheBlockIds(windowSize).at(0).size();
}
}
// In case of sliding window attention, a new block is allocated when the
// window slides (and then the out-of-window block is detached). So we
// need an extra block for generation if the diff between the max sequence
// length and the current sequence length crosses both a block boundary
// and a window boundary.
auto const isSlidingWindow = (req.mPromptLen + req.mMaxNewTokens) > windowSize;
SizeType32 const currentSeqlenInBlocks = tc::ceilDiv(req.getNumTokens(0), getTokensPerBlock());
SizeType32 const maxSeqlenInBlocks = tc::ceilDiv(req.mPromptLen + req.mMaxNewTokens, getTokensPerBlock());
auto const willCrossBlockBoundary = maxSeqlenInBlocks > currentSeqlenInBlocks;
auto const willCrossWindowBlockBoundary = maxSeqlenInBlocks > numTotalBlocksPerBeam;
SizeType32 numExtraBlocksPerBeam
= isSlidingWindow && willCrossBlockBoundary && willCrossWindowBlockBoundary ? 1 : 0;
if (numAllocBlocksPerBeam < numContextBlocks) // Still haven't allocated all context blocks
{
return numContextBlocks - numAllocBlocksPerBeam
+ (numGenBlocksPerBeam + numExtraBlocksPerBeam) * req.mSamplingConfig.beamWidth;
}
return (numTotalBlocksPerBeam - numAllocBlocksPerBeam + numExtraBlocksPerBeam) * req.mSamplingConfig.beamWidth;
}
void BlockManager::updateSequenceCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
{
auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
auto const beamWidth = sequence.getBeamWidth();
auto* offsetsPtr = bufferCast<tk::KVCacheIndex>(cacheBlocksTensor);
auto const& offsetsShape = cacheBlocksTensor.getShape();
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto const& beamCacheBlock = cacheBlocks[beamIdx];
for (SizeType32 blockIdx = 0; blockIdx < static_cast<SizeType32>(beamCacheBlock.size()); ++blockIdx)
{
auto const blockId = beamCacheBlock.at(blockIdx);
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
}
}
void WindowBlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence)
{
auto const& cacheBlocks = sequence.getCacheBlockIds(mWindowSize);
auto& cacheBlocksTensor = sequence.getCacheBlockIndices(mWindowSize);
auto const beamWidth = sequence.getBeamWidth();
auto* offsetsPtr = bufferCast<tk::KVCacheIndex>(cacheBlocksTensor);
auto const& offsetsShape = cacheBlocksTensor.getShape();
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto const& beamCacheBlock = cacheBlocks[beamIdx];
auto const blockId = beamCacheBlock.back();
auto const blockIdx = static_cast<SizeType32>(beamCacheBlock.size() - 1);
setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
}
void BlockManager::updateCacheBlockOffsetsAtIdx(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
{
auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
auto const beamWidth = sequence.getBeamWidth();
auto* offsetsPtr = bufferCast<tk::KVCacheIndex>(cacheBlocksTensor);
auto const& offsetsShape = cacheBlocksTensor.getShape();
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto const& beamCacheBlock = cacheBlocks[beamIdx];
auto const blockId = beamCacheBlock.at(blockIdx);
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
}
void KVCacheManager::addToken(RequestIdType requestId)
{
// TODO: add streamLLM support
auto& sequence = getSequence(requestId);
sequence.addNewTokens(1);
mBlockManager.adjustBlocksIfNeeded(sequence);
}
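//! \brief Detach the oldest out-of-window block from the front of a sequence.
//! \details Drops the sequence's reference on the block and returns it to the eviction policy when no
//!          other sequence still references it. Only beamWidth == 1 is supported.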
void WindowBlockManager::detachFrontBlock(GenerationRequest& sequence)
{
    // streamLLM is not supported at the moment. The out-of-window block is
    // always the oldest block that has not yet been detached from the front.
TLLM_CHECK_WITH_INFO(
sequence.getBeamWidth() == 1, "[kv cache manager] detachBlock does not support beamWidth > 1 now.");
auto const requestId = sequence.getRequestId();
auto const beamWidth = sequence.getBeamWidth();
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
SizeType32 outOfWindowBlockIdx = sequence.getNumFrontBlocksRemoved();
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto outOfWindowBlock = allocatedBlocks.at(outOfWindowBlockIdx * beamWidth + beamIdx);
TLLM_LOG_DEBUG("%s::detachFrontBlock - Detaching block %d from sequence %d", mLogPrefix.c_str(),
outOfWindowBlock->getBlockId(), requestId);
outOfWindowBlock->decRefCount();
if (outOfWindowBlock->hasRefs())
{
TLLM_LOG_DEBUG("%s::detachFrontBlock - OOW Block %d still has a non-zero ref count", mLogPrefix.c_str(),
outOfWindowBlock->getBlockId());
}
if (!outOfWindowBlock->hasRefs())
{
mEvictionPolicy->releaseBlock(outOfWindowBlock);
}
}
// Disconnect first block from sequence and remove it from allocated blocks
sequence.removeFrontBlock(mWindowSize);
}
std::optional<BlockKey> KVCacheManager::findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const
{
auto newContextBlockOpt = mBlockManager.findNewContextBlock(uniqueTokens, llmRequest);
return newContextBlockOpt;
}
void KVCacheManager::addSequence(
RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef<LlmRequest> llmRequest)
{
// TODO: add streamLLM support
auto kvCacheRetentionConfig = llmRequest
? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig())
: executor::KvCacheRetentionConfig();
auto const [seqIt, emplaceDone] = [&]
{
auto lck = std::scoped_lock(mSequencesMtx);
return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth,
mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig);
}();
TLLM_CHECK(emplaceDone);
auto& sequence = seqIt->second;
    // Snapshot block allocation/reuse statistics before processing this request.
SizeType32 const numAllocTotalBlocksPreRequest = mBlockManager.getNumAllocTotalBlocks();
SizeType32 const numAllocNewBlocksPreRequest = mBlockManager.getNumAllocNewBlocks();
SizeType32 const numReusedBlocksPreRequest = mBlockManager.getNumReusedBlocks();
SizeType32 const numMissedBlocksPreRequest = mBlockManager.getNumMissedBlocks();
if (!mBlockManager.isSequenceHeld(requestId))
{
mBlockManager.holdSequence(requestId);
        TLLM_LOG_DEBUG(
            "[kv cache manager] Encountered new sequence %lu, initializing sequence storage validity for all "
            "window sizes",
            requestId);
    }
    else
    {
        TLLM_LOG_DEBUG(
            "[kv cache manager] Encountered existing sequence %lu, skipping sequence storage validity "
            "initialization",
            requestId);
}
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
// NOTE: Caller to KVCacheManager::addSequence should deal with the chunking
auto const maxTokenNum = metadata.maxTokenNum;
auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow;
// Consider the temporaryAttentionWindow when allocating blocks.
auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow);
auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
if (mEnableBlockReuse)
{
mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
}
else
{
            if (llmRequest && llmRequest->getKvCacheRetentionConfig().has_value())
            {
                TLLM_LOG_WARNING(
                    "Request %lu has a retention configuration set, but block reuse is disabled. "
                    "The retention config will have no effect.",
                    llmRequest->mRequestId);
}
bool isShareLastContextBlock = isCrossKv() || effectiveInputLength % getTokensPerBlock() == 0;
mBlockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock);
}
mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
}
if (llmRequest)
{
// Update statistics for block allocations/reuse per request.
llmRequest->updateAllocTotalBlocksPerRequest(
mBlockManager.getNumAllocTotalBlocks() - numAllocTotalBlocksPreRequest);
llmRequest->updateAllocNewBlocksPerRequest(mBlockManager.getNumAllocNewBlocks() - numAllocNewBlocksPreRequest);
llmRequest->updateReusedBlocksPerRequest(mBlockManager.getNumReusedBlocks() - numReusedBlocksPreRequest);
llmRequest->updateMissedBlocksPerRequest(mBlockManager.getNumMissedBlocks() - numMissedBlocksPreRequest);
}
}
void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest)
{
auto const requestId = llmRequest.mRequestId;
bool found = false;
{
// protect the mSequences
std::scoped_lock lock(mSequencesMtx);
found = mSequences.find(requestId) != mSequences.end();
}
if (found)
{
auto& sequence = getSequence(requestId);
if (mEnableBlockReuse && !llmRequest.isDummyRequest())
{
mBlockManager.storeContextBlocks(sequence, llmRequest);
}
}
else
{
TLLM_LOG_WARNING("[kv cache manager] storeContextBlocks: Can not find sequence for request %lu", requestId);
}
}
void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest)
{
    // We store the newest block for potential reuse only if:
    // - Beam search is NOT enabled.
    // - Block reuse is enabled.
auto const requestId = llmRequest.mRequestId;
auto& sequence = getSequence(requestId);
if (sequence.getBeamWidth() > 1 || !mEnableBlockReuse)
{
return;
}
mBlockManager.storeNewBlock(sequence, llmRequest);
}
std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
auto sequenceNode = [this, requestId]
{
std::scoped_lock lock(mSequencesMtx);
return mSequences.extract(requestId);
}();
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
if (!sequenceNode.empty())
{
if (mEnableBlockReuse)
{
lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), llmRequest, pinBlocks);
}
else
{
lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), std::nullopt, pinBlocks);
}
}
if (mBlockManager.isSequenceHeld(requestId))
{
mBlockManager.releaseSequence(requestId);
TLLM_LOG_DEBUG("Remove sequence %d, release sequence storage validity for all window sizes", requestId);
}
TLLM_CHECK(!mBlockManager.isSequenceHeld(requestId));
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
return lastStoredId;
}
std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
auto& sequence = getSequence(requestId);
std::optional<KVCacheBlock::IdType> lastStoredId
= mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
return lastStoredId;
}
void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
{
// Mimic Free all blocks for this sequence
mBlockManager.schedulingReleaseBlocks(requestId);
}
void KVCacheManager::pinBlocks(RequestIdType requestId)
{
auto& sequence = getSequence(requestId);
mBlockManager.pinBlocks(sequence);
}
void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
{
mBlockManager.unpinBlocksById(blockId);
}
SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
{
auto const& sequence = getSequence(requestId);
auto const beamWidth = sequence.getBeamWidth();
auto* dstPtr = bufferCast<tk::KVCacheIndex>(output);
auto const& dstShape = output.getShape();
SizeType32 constexpr kIdx = 0;
SizeType32 constexpr vIdx = 1;
SizeType32 maxBlockCount{0};
// Get page table for each KV cache pool
SizeType32 absolutePoolIdx = 0;
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
auto const& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
auto const* srcPtr = bufferCast<tk::KVCacheIndex>(cacheBlocksTensor);
auto const& srcShape = cacheBlocksTensor.getShape();
auto const& cacheBlockIds = sequence.getCacheBlockIds(windowSize);
for (SizeType32 poolIdx = 0; poolIdx < metadata.numPools; poolIdx++, absolutePoolIdx++)
{
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto const beamBlockCount = cacheBlockIds[beamIdx].size();
auto const copyChunkSize = beamBlockCount * sizeof(tk::KVCacheIndex);
for (auto xIdx : {kIdx, vIdx})
{
auto const srcIndex = tc::flat_index(srcShape.d, poolIdx, beamIdx, xIdx, 0);
auto const dstIndex
= tc::flat_index(dstShape.d, absolutePoolIdx, outputSlotOffset + beamIdx, xIdx, 0);
std::memcpy(dstPtr + dstIndex, srcPtr + srcIndex, copyChunkSize);
}
maxBlockCount = std::max<SizeType32>(maxBlockCount, static_cast<SizeType32>(beamBlockCount));
}
}
}
return maxBlockCount;
}
void KVCacheManager::getBlockOffsetsOfBatch(
ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const
{
// Get page table for each KV cache pool
for (auto batchSlotIdx = 0; batchSlotIdx < batchSize; ++batchSlotIdx)
{
copyBlockOffsets(output, batchSlotIdx * beamWidth, firstBatchSlotIdx + batchSlotIdx);
}
}
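//! \brief Group layer indices by their (unique) attention window size.
//! \details For example, maxAttentionWindowVec = {512, 32768} with numLayers = 4 yields
//!          {512: {0, 2}, 32768: {1, 3}}.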
std::map<SizeType32, std::vector<SizeType32>> BaseKVCacheManager::groupLayersByWindowSize(
std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLayers)
{
auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
for (SizeType32 layerIdx = 0; layerIdx < numLayers; layerIdx++)
{
/*
At this point (Deep in the construction of TrtGptModel), maxAttentionWindowVec isn't "stretched" to the
length of numLayers yet. So, we need to rotate the window sizes per layer with modulo.
*/
auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
}
return uniqueWindowSizeToLayers;
}
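//! \brief Compute the primary (GPU) and secondary (host) memory budgets available for the KV cache.
//! \details The primary budget is freeGpuMemoryFraction of the currently free device memory (including
//!          memory pool slack); the secondary budget is the configured host cache size.
//! \return tuple of {freePrimaryMemBytes, freeSecondaryMemBytes}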
std::tuple<uint64_t, uint64_t> BaseKVCacheManager::calculateFreeMemBytes(
runtime::BufferManager const& bufferManager, executor::KvCacheConfig const& config)
{
auto const freeMemFraction
= config.getFreeGpuMemoryFraction().value_or(executor::KvCacheConfig::kDefaultGpuMemFraction);
TLLM_CHECK_WITH_INFO(freeMemFraction < 1.0F,
"Invalid freeMemFraction, freeMemFraction (%f) must be smaller than 1.0f", freeMemFraction);
if (config.getMaxTokens().has_value())
{
if (config.getFreeGpuMemoryFraction().has_value())
{
TLLM_LOG_WARNING(
"Both freeGpuMemoryFraction (aka kv_cache_free_gpu_mem_fraction) "
"and maxTokens (aka max_tokens_in_paged_kv_cache) "
"are set (to %f and %ld, respectively). The smaller value will be used.",
freeMemFraction, (int64_t) config.getMaxTokens().value());
}
}
TLLM_CUDA_CHECK(::cudaDeviceSynchronize());
auto const [freeMem, totalMem] = tc::getDeviceMemoryInfo(config.getUseUvm());
auto const finalFreeMem = freeMem + bufferManager.memoryPoolFree();
TLLM_LOG_INFO("Memory usage when calculating max tokens in paged kv cache: total: %0.2f GiB, available: %0.2f GiB",
totalMem / static_cast<double>(1 << 30), finalFreeMem / static_cast<double>(1 << 30));
TLLM_CHECK_WITH_INFO(finalFreeMem <= totalMem, "Free memory cannot exceed total memory");
auto const freePrimaryMemBytes = static_cast<uint64_t>(finalFreeMem * freeMemFraction);
auto const freeSecondaryMemBytes = config.getHostCacheSize().value_or(0);
TLLM_LOG_DEBUG("Calculated free memory: {.freePrimaryMemBytes=%" PRIu64 ", .freeSecondaryMemBytes=%" PRIu64 "}",
freePrimaryMemBytes, freeSecondaryMemBytes);
return std::make_tuple(freePrimaryMemBytes, freeSecondaryMemBytes);
}
namespace
{
bool isSortedVectorIdenticalAcrossAllRanks(WorldConfig const& worldConfig, std::vector<SizeType32> const& vector)
{
auto const numRanks = worldConfig.getSize();
auto const numElements = static_cast<int>(vector.size());
int maxNumElements = 0;
int minNumElements = 0;
COMM_SESSION.allreduce(&numElements, &maxNumElements, 1, mpi::MpiType::kINT32, mpi::MpiOp::MAX);
COMM_SESSION.allreduce(&numElements, &minNumElements, 1, mpi::MpiType::kINT32, mpi::MpiOp::MIN);
if (maxNumElements != minNumElements)
{
return false;
}
std::vector<SizeType32> allElements(numElements * numRanks);
COMM_SESSION.allgather(vector.data(), allElements.data(), numElements, mpi::MpiType::kUINT32);
for (int i = 0; i < numElements; ++i)
{
auto const ref = allElements.at(i);
for (int rank = 1; rank < numRanks; ++rank)
{
if (allElements[rank * numElements + i] != ref)
return false;
}
}
return true;
}
} // namespace
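//! \brief Compute the number of primary and secondary KV cache blocks to allocate for each window size.
//! \details The memory budget is split across window sizes (equally by default, or according to
//!          TRTLLM_WINDOW_SIZE_SHARES), converted to a token budget using the per-token cache size of
//!          the layers managed by each window, and then to block counts. With multiple ranks, the
//!          minimum block count across ranks is used so that all ranks agree.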
BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfig const& config, bool isCrossAttention,
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes,
uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor)
{
TLLM_LOG_DEBUG("Calculating max num blocks for %s: {.allottedPrimaryMemBytes=%" PRIu64
", .allottedSecondaryMemBytes=%" PRIu64 "}",
isCrossAttention ? "Cross KvCacheManager" : "Self KvCacheManager", allottedPrimaryMemBytes,
allottedSecondaryMemBytes);
if (config.getMaxTokens().has_value() && windowSizeToLayers.size() > 1)
{
TLLM_LOG_WARNING(
"Setting maxTokens when using Variable Sliding Window Attention is a strange concept, as it limits "
"the number of max tokens *per window size* [limiting the sum of all window sizes is even stranger]. "
"Anticipating the effects of this requires quite a complex calculation, and it probably isn't the "
"configuration you meant to use.");
}
std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
for (auto const& [windowSize, managedLayers] : windowSizeToLayers)
{
auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
modelConfig, managedLayers, isCrossAttention, kvFactor);
auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(dtype).getSize();
cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
}
bool const isVSWA = cacheSizeBytesPerTokenPerWindow.size() > 1;
TLLM_LOG_DEBUG("extraCostMemory [Gib]: %0.2f", extraCostMemory / static_cast<double>(1 << 30));
allottedPrimaryMemBytes = allottedPrimaryMemBytes - extraCostMemory;
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
auto const calculatePrimaryBlocks
= [&](SizeType32 windowSize, float windowSizeShare, SizeType32 cacheSizeBytesPerToken)
{
TLLM_LOG_DEBUG("windowSizeShare: %f, cacheSizeBytesPerToken: %d", windowSizeShare, cacheSizeBytesPerToken);
auto maxTokens = static_cast<uint64_t>(
allottedPrimaryMemBytes * windowSizeShare / static_cast<double>(cacheSizeBytesPerToken));
// kv_cache_config.max_tokens is not effective in VSWA scheme
if (config.getMaxTokens().has_value() && !isVSWA)
{
auto const maxTokensFromConfig = static_cast<uint64_t>(config.getMaxTokens().value());
if (maxTokensFromConfig < maxTokens)
{
TLLM_LOG_DEBUG("Maximum kv-cache token overridden by configuration as '%ld'.", maxTokensFromConfig);
maxTokens = std::min(maxTokensFromConfig, maxTokens);
}
}
TLLM_LOG_DEBUG("Primary maxTokens for windowSize %d: %ld", windowSize, maxTokens);
SizeType32 const blocksInPrimaryPool = tc::ceilDiv(maxTokens, tokensPerBlock);
TLLM_LOG_DEBUG(
"Number of blocks in KV cache primary pool for windowSize %d: %d", windowSize, blocksInPrimaryPool);
return blocksInPrimaryPool;
};
auto const calculateSecondaryBlocks
= [&](SizeType32 windowSize, float windowSizeShare, SizeType32 cacheSizeBytesPerToken)
{
auto const maxTokensSecondary
= static_cast<SizeType32>(allottedSecondaryMemBytes * windowSizeShare / cacheSizeBytesPerToken);
SizeType32 const blocksInSecondaryPool = std::max(0, maxTokensSecondary / tokensPerBlock);
TLLM_LOG_DEBUG(
"Number of blocks in KV cache secondary pool for windowSize %d: %d, onboard blocks to primary memory "
"before reuse: %s",
windowSize, blocksInSecondaryPool, config.getOnboardBlocks() ? "true" : "false");
return blocksInSecondaryPool;
};
std::map<SizeType32, float> windowSizeToShare;
// By default, we allocate equal proportion shares of memory for all
// window sizes (see the else case). With TRTLLM_WINDOW_SIZE_SHARES,
// we can override this behavior to adjust the memory share of each
// window size. For example, if we have window size of [512, 32768],
// then setting TRTLLM_WINDOW_SIZE_SHARES=0.4,0.6 will be allocating
// 40% of the memory to window size 512 and 60% of the memory to window
// size 32768.
if (auto envStr = std::getenv("TRTLLM_WINDOW_SIZE_SHARES"))
{
std::stringstream ss(envStr);
std::vector<float> shares;
float share;
while (ss >> share)
{
shares.push_back(share);
if (ss.peek() == ',')
ss.ignore();
}
TLLM_CHECK_WITH_INFO(shares.size() == windowSizeToLayers.size(),
"Number of shares in TRTLLM_WINDOW_SIZE_SHARES (%ld) must match number of window sizes (%ld)",
shares.size(), windowSizeToLayers.size());
float sumShares = 0.0f;
for (auto s : shares)
{
TLLM_CHECK_WITH_INFO(0.0f <= s && s <= 1.0f, "Shares must be in value range [0,1], got %f", s);
sumShares += s;
}
TLLM_CHECK_WITH_INFO(sumShares > 0.0f, "Sum of shares must be > 0.");
// Normalize shares to 1.0
for (auto& s : shares)
{
s /= sumShares;
}
size_t i = 0;
for (auto const& [windowSize, _] : windowSizeToLayers)
{
windowSizeToShare[windowSize] = shares[i++];
}
}
else
{
        // NOTE: Ideally, the number of blocks allocated should be proportional
        // to the window size. Currently, we first allocate an identical number
        // of blocks for all layers to achieve identical performance.
for (auto const& [windowSize, _] : windowSizeToLayers)
{
windowSizeToShare[windowSize] = 1.0f / windowSizeToLayers.size();
}
}
std::vector<SizeType32> blocksPrimary;
std::vector<SizeType32> blocksSecondary;
for (auto const& [windowSize, managedLayers] : windowSizeToLayers)
{
auto const cacheSizeBytesPerToken = cacheSizeBytesPerTokenPerWindow.at(windowSize);
auto const windowSizeShare = windowSizeToShare.at(windowSize);
auto const blocksInPrimaryPool = calculatePrimaryBlocks(windowSize, windowSizeShare, cacheSizeBytesPerToken);
auto const blocksInSecondaryPool
= calculateSecondaryBlocks(windowSize, windowSizeShare, cacheSizeBytesPerToken);
blocksPrimary.push_back(blocksInPrimaryPool);
blocksSecondary.push_back(blocksInSecondaryPool);
}
std::vector<SizeType32> windowSizes;
windowSizes.reserve(windowSizeToLayers.size());
for (auto const& [k, _] : windowSizeToLayers)
{
windowSizes.push_back(k);
}
if (worldConfig.getSize() > 1)
{
TLLM_CHECK(worldConfig.validMpiConfig());
auto const rank = worldConfig.getRank();
using tensorrt_llm::common::vec2str;
TLLM_CHECK_WITH_INFO(isSortedVectorIdenticalAcrossAllRanks(
worldConfig, windowSizes), // sorted thanks to windowSizeToLayers being a std::map
"[RANK %d] Asymmetrical pipeline parallelism detected: Ranks either have a different number of window "
"sizes, or differing values. This is not supported with Variable Sliding Window Attention. Local window "
"sizes for reference: %s",
rank, vec2str(windowSizes).c_str());
TLLM_LOG_DEBUG(
"[RANK %d] Before mpi::MpiOp::MIN reduction: window sizes %s / primary blocks %s / secondary blocks %s",
rank, vec2str(windowSizes).c_str(), vec2str(blocksPrimary).c_str(), vec2str(blocksSecondary).c_str());
// make sure all ranks use same value for max blocks
auto blocksWorld = blocksPrimary;
COMM_SESSION.allreduce(
blocksPrimary.data(), blocksWorld.data(), windowSizes.size(), mpi::MpiType::kINT32, mpi::MpiOp::MIN);
blocksPrimary = blocksWorld;
COMM_SESSION.allreduce(
blocksSecondary.data(), blocksWorld.data(), windowSizes.size(), mpi::MpiType::kINT32, mpi::MpiOp::MIN);
blocksSecondary = blocksWorld;
TLLM_LOG_DEBUG(
"[RANK %d] After mpi::MpiOp::MIN reduction: window sizes %s / primary blocks %s / secondary blocks %s",
rank, vec2str(windowSizes).c_str(), vec2str(blocksPrimary).c_str(), vec2str(blocksSecondary).c_str());
}
BlocksPerWindow windowSizeToBlocks;
TLLM_LOG_INFO("Blocks per window size:");
for (size_t i = 0; i < windowSizes.size(); ++i)
{
auto const windowSize = windowSizes.at(i);
auto const primaryBlocks = blocksPrimary.at(i);
        auto const secondaryBlocks = blocksSecondary.at(i);
        TLLM_LOG_INFO(
            "[windowSize=%d] {.primaryBlocks=%d, .secondaryBlocks=%d}", windowSize, primaryBlocks, secondaryBlocks);
        windowSizeToBlocks[windowSize] = {primaryBlocks, secondaryBlocks};
}
return windowSizeToBlocks;
}
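//! \brief Remove the last token from a sequence and release the last block of any window whose token
//!        count falls back onto a block boundary. Only beamWidth == 1 is supported.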
void KVCacheManager::removeToken(RequestIdType requestId)
{
// TODO: add streamLLM support
auto& sequence = getSequence(requestId);
if (sequence.getNumTokens() == 0)
{
return;
}
TLLM_CHECK_WITH_INFO(sequence.getBeamWidth() == 1, "[kv cache manager] removeToken does not support beamWidth > 1");
sequence.removeTokens(1);
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
SizeType32 const tokensInWindow = sequence.getNumTokens() % windowSize;
if (tokensInWindow % getTokensPerBlock() == 0)
{
mBlockManager.releaseLastBlock(sequence, windowSize);
}
}
}
void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths)
{
for (SizeType32 si = 0; si < rewindLengths; ++si)
{
removeToken(requestId);
}
}
GenerationRequest const& KVCacheManager::getSequence(RequestIdType requestId) const
{
auto lck = std::scoped_lock(mSequencesMtx);
return mSequences.at(requestId);
}
GenerationRequest& KVCacheManager::getSequence(RequestIdType requestId)
{
auto lck = std::scoped_lock(mSequencesMtx);
return mSequences.at(requestId);
}
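//! \brief Compute the "bubble" padding needed so that the sink tokens fill whole blocks.
//! \details For example, with tokensPerBlock = 64 and sinkTokenLen = 5 the bubble is 64 - 5 = 59 tokens;
//!          when sinkTokenLen is a multiple of tokensPerBlock (including 0) the bubble is 0.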
SizeType32 BaseKVCacheManager::getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock)
{
auto const sinkTokensInLastBlock = sinkTokenLen % tokensPerBlock;
auto const sinkBubbleLength = sinkTokensInLastBlock == 0 ? 0 : tokensPerBlock - sinkTokensInLastBlock;
return sinkBubbleLength;
}
std::vector<std::vector<SizeType32>> const& KVCacheManager::getCacheBlockIds(
RequestIdType requestId, SizeType32 windowSize) const
{
return getSequence(requestId).getCacheBlockIds(windowSize);
}
std::vector<std::vector<std::vector<SizeType32>>> KVCacheManager::getBatchCacheBlockIds(
std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const
{
std::vector<std::vector<std::vector<SizeType32>>> result{};
result.reserve(requestIds.size());
for (auto const& requestId : requestIds)
{
auto const& sequence = getSequence(requestId);
result.emplace_back(sequence.getCacheBlockIds(windowSize));
}
return result;
}
std::optional<KVCacheBlock::IdType> KVCacheManager::getLastBlockId(LlmRequest::RequestIdType requestId) const
{
auto const& seq = getSequence(requestId);
// Use the first window size
auto firstWindowSize = mBlockManager.getFirstWindowSize();
if (firstWindowSize == 0)
{
return std::nullopt;
}
auto const& perBeam = seq.getCacheBlockIds(firstWindowSize);
if (perBeam.empty() || perBeam[0].empty())
{
return std::nullopt;
}
return perBeam[0].back();
}
runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
{
TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1,
"getUniquePrimaryPool is only supported for a single window size");
return mBlockManager.getPrimaryPool(0);
}
runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
{
return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
}
runtime::ITensor::SharedPtr KVCacheManager::getIndexerKCachePool() const
{
return mIndexerKCachePoolPointers;
}
SizeType32 KVCacheManager::getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const
{
auto minMaxBatchSizeAllWindows = std::numeric_limits<SizeType32>::max();
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
auto const blockRequirementsPerSequence = KVCacheManager::calculateMaxBlockRequirements(
inputLength, outputLength, mSinkBlockTokenLength, windowSize, mMaxBeamWidth, mTokensPerBlock);
auto const maxBatchSizeWindow = metadata.allottedPrimaryBlocks / blockRequirementsPerSequence;
// The window with the *smallest* max batch size is the limiting factor
// Hence, the std::*min* of all the max batch sizes is chosen
minMaxBatchSizeAllWindows = std::min(minMaxBatchSizeAllWindows, maxBatchSizeWindow);
}
return minMaxBatchSizeAllWindows;
}
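//! \brief Compute the maximum number of blocks a single beam may need for a sequence of the given length.
//! \details The sequence length is capped at the attention window, and kSWAExtraBlock is added when the
//!          sequence exceeds the window. For example, with sequenceLength = 200, windowSize = 128,
//!          tokensPerBlock = 64 and no sink tokens, this is ceilDiv(128, 64) = 2 blocks plus kSWAExtraBlock.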
SizeType32 KVCacheManager::calculateMaxBlockRequirementsPerBeam(
SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock)
{
auto const sinkBubbleLength = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock);
auto const actualSeqLen = std::min(sequenceLength, windowSize);
auto actualMaxTokenNum = actualSeqLen + sinkBubbleLength;
auto numBlocks = tc::ceilDiv(actualMaxTokenNum, tokensPerBlock);
if (sequenceLength > windowSize)
{
numBlocks += kSWAExtraBlock;
}
return numBlocks;
}
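//! \brief Compute the maximum number of blocks needed for a sequence of inputLength + outputLength tokens.
//! \details For beamWidth == 1 this is the per-beam requirement. For beam search, context blocks within
//!          the attention window are counted once while output blocks are counted once per beam.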
SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock)
{
// We split beam width > 1, as it introduces a lot of complexity.
auto const wholeSequenceLength = inputLength + outputLength;
if (beamWidth == 1)
{
return KVCacheManager::calculateMaxBlockRequirementsPerBeam(
wholeSequenceLength, sinkTokenLength, windowSize, tokensPerBlock);
}
if (windowSize <= outputLength)
{
        // For SWA, each beam needs at most enough distinct blocks to hold outputLength tokens.
return KVCacheManager::calculateMaxBlockRequirementsPerBeam(
outputLength, sinkTokenLength, windowSize, tokensPerBlock)
* beamWidth;
}
// Otherwise, we calculate how many tokens will be in output blocks.
auto const effectiveAttentionWindow = std::min(windowSize, wholeSequenceLength);
auto const numContextTokensInAttentionWindow
= effectiveAttentionWindow - outputLength; // This is positive because we handled the other case above.
auto const sinkBubbleLength = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock);
auto const numContextBlocks = (numContextTokensInAttentionWindow + sinkBubbleLength) / tokensPerBlock;
auto const leftoverContextToken = numContextTokensInAttentionWindow - numContextBlocks * tokensPerBlock;
auto numOutputBlocks = tc::ceilDiv(outputLength + leftoverContextToken, tokensPerBlock);
if (wholeSequenceLength > windowSize)
{
numOutputBlocks += kSWAExtraBlock;
}
return numContextBlocks + numOutputBlocks * beamWidth;
}
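//! \brief Compute the largest attention window that still lets a sequence of the given shape fit into
//!        blockCapacity blocks for the given beam width and tokens per block.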
[[nodiscard]] SizeType32 KVCacheManager::calculateMaxAttentionWindow(SizeType32 inputLength, SizeType32 outputLength,
SizeType32 sinkTokenLength, SizeType32 blockCapacity, SizeType32 beamWidth, SizeType32 tokensPerBlock)
{
    // The function that gives the number of blocks required for a given attention window is only piecewise
    // linear. It is, however, monotonically increasing in the attention window, so we need to find which piece
    // of the function applies. First, are we in a case where not even the entire output will fit?
auto const outputBlockRequirements
= calculateMaxBlockRequirements(0, outputLength, sinkTokenLength, outputLength, beamWidth, tokensPerBlock);
if (outputBlockRequirements > blockCapacity)
{
return (blockCapacity / beamWidth) * tokensPerBlock;
}
// Otherwise, we need to determine how many context tokens we can fit on top of the output tokens. First, there
// are a few context tokens we might be able to fit 'for free' because the output is not a multiple of the
// number of tokens per block.
auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager