/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/evictionPolicy.h"
#include "tensorrt_llm/batch_manager/kvCacheTransferManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/worldConfig.h"

#include
#include
#include
#include

namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
namespace tle = tensorrt_llm::executor;

using namespace tle::kv_cache;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
using namespace tensorrt_llm::batch_manager::eviction_policy;

using BlocksPerWindow = std::map<SizeType32, std::pair<SizeType32, SizeType32>>;

namespace
{

inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept
{
    return static_cast<uint8_t>((hashPart >> (24 - byteIdx * 8)) & 0xFF);
}

//! \brief Get all blocks in a sequence by traversing backwards from the last block.
//! \param lastBlock is a BlockPtr to the last block in the sequence to start traversal from
//! \return Vector of BlockPtr-s in sequence order
std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
{
    // First count the number of blocks to pre-allocate the vector
    auto currentBlock = lastBlock;
    size_t blockCount = 0;
    while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
    {
        blockCount++;
        currentBlock = currentBlock->getPrevBlockInSeq();
    }
    if (blockCount == 0)
    {
        return {};
    }

    // Create and pre-allocate the vector with the correct size
    std::vector<BlockPtr> sequenceBlocks(blockCount);

    // Now traverse backwards and fill from the end
    currentBlock = lastBlock;
    size_t currentIndex = blockCount - 1;
    while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
    {
        sequenceBlocks[currentIndex--] = currentBlock;
        currentBlock = currentBlock->getPrevBlockInSeq();
    }
    return sequenceBlocks;
}

} // namespace

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

std::vector<MmKey> generateBlockHashExtraKeys(
    tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx)
{
    auto const multimodalHashes = llmRequest.getMultimodalHashes();
    auto const multimodalPositions = llmRequest.getMultimodalPositions();
    auto const multimodalLengths = llmRequest.getMultimodalLengths();

    if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
        || (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
        || !(*multimodalLengths) || (*multimodalLengths)->empty())
    {
        return {};
    }

    if ((*multimodalHashes)->size() != (*multimodalPositions)->size()
        || (*multimodalPositions)->size() != (*multimodalLengths)->size())
    {
        TLLM_LOG_WARNING("Multimodal data arrays have mismatched sizes");
        return {};
    }

    std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
    extraKeys.reserve((*multimodalPositions)->size());
    std::array<uint8_t, 32> mmHashArray;

    for (size_t i = 0; i < (*multimodalPositions)->size(); ++i)
    {
        auto const& startPos = (*(*multimodalPositions))[i];
        auto const& length = (*(*multimodalLengths))[i];
        auto const& mmHashVector = (*(*multimodalHashes))[i];

        TLLM_CHECK_WITH_INFO(mmHashVector.size() == 8, "Multimodal hash vector has unexpected size: %zu (expected 8)",
            mmHashVector.size());

        // mmHashVector[j] comes from Python's int(hex_chunk, 16)
        // where hex_chunk like "00010203" means 0x00 is MSB and 0x03 is LSB (big endian)
        // Convert 8x 32-bit integers into a 32-byte array preserving Blake3 hash byte order
        // Example: hashPart = 0x00010203 → mmHashArray[0:3] = [0x00, 0x01, 0x02, 0x03]
        for (size_t j = 0; j < 8; ++j)
        {
            auto const& hashPart = mmHashVector[j];
            for (uint8_t byteIdx = 0; byteIdx < 4; ++byteIdx)
            {
                mmHashArray[j * 4 + byteIdx] = getNthByte(hashPart, byteIdx);
            }
        }

        // Check if this multimodal content overlaps with the current block
        if (endTokenIdx > startPos && startTokenIdx < startPos + length)
        {
            uint64_t mmStartInBlock = (startPos >= startTokenIdx) ?
0 : static_cast(startTokenIdx - startPos); extraKeys.emplace_back(mmHashArray, mmStartInBlock); } } return extraKeys; } std::vector buildBlockKeys( std::list& blockedUniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest) { std::vector blockKeys; SizeType32 currentTokenIdx = 0; for (auto& uniqueTokens : blockedUniqueTokens) { auto extraKeys = generateBlockHashExtraKeys(llmRequest, currentTokenIdx, currentTokenIdx + uniqueTokens.size()); currentTokenIdx += uniqueTokens.size(); blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID()); } return blockKeys; } bool BlockKey::operator==(BlockKey const& other) const noexcept { return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID); } size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept { // Hashing algorithm adapted from StackOverflow: // https://stackoverflow.com/questions/664014/what-integer-hash-function-are-good-that-accepts-an-integer-hash-key // Constants provide very good distribution - each input bit affects each output bit with ~50% probability. size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9); if (parentHash == 0 && blockKey.cacheSaltID) { // Only hashing the cache salt ID for the first block in the sequence uint64_t c = blockKey.cacheSaltID.value(); c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb); c = c ^ (c >> 31); seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2); } for (auto const& uniqueToken : blockKey.uniqueTokens) { uint32_t a = static_cast(uniqueToken.tokenId); a = ((a >> 16) ^ a) * 0x45d9f3b; a = ((a >> 16) ^ a) * 0x45d9f3b; a = (a >> 16) ^ a; seed ^= a + 0x9e3779b9 + (seed << 6) + (seed >> 2); if (blockKey.usesExtraIds) { uint64_t b = uniqueToken.tokenExtraId; b = (b ^ (b >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); b = (b ^ (b >> 27)) * UINT64_C(0x94d049bb133111eb); b = b ^ (b >> 31); seed ^= b + 0x9e3779b9 + (seed << 6) + (seed >> 2); } } if (blockKey.loraTaskId) { uint64_t c = blockKey.loraTaskId.value(); c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb); c = c ^ (c >> 31); seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2); } // Add extra keys for multimodal data mixing in external multimodal item hash and token offset within this sequence // block if (!blockKey.extraKeys.empty()) { for (auto const& [mmHash, startOffset] : blockKey.extraKeys) { // Hash the multimodal hash array in 32-bit chunks (more efficient) for (size_t i = 0; i < 32; i += 4) { // Combine 4 bytes into a 32-bit word (construct as little endian order) uint32_t word = static_cast(mmHash[i]) | (static_cast(mmHash[i + 1]) << 8) | (static_cast(mmHash[i + 2]) << 16) | (static_cast(mmHash[i + 3]) << 24); // Mix the word into the seed word = ((word >> 16) ^ word) * 0x45d9f3b; word = ((word >> 16) ^ word) * 0x45d9f3b; word = (word >> 16) ^ word; seed ^= word + 0x9e3779b9 + (seed << 6) + (seed >> 2); } // Hash the start offset uint64_t e = static_cast(startOffset); e = (e ^ (e >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); e = (e ^ (e >> 27)) * UINT64_C(0x94d049bb133111eb); e = e ^ (e >> 31); seed ^= e + 0x9e3779b9 + (seed << 6) + (seed >> 2); } } return seed; } KVCacheBlock::KVCacheBlock(IdType blockId, tk::KVCacheIndex blockIdx) : 
mBlockId(blockId) , mMemoryPoolBlockIndex{blockIdx} , mRefCount(0) , mSchedulingRefCount(0) , mPrevBlock(nullptr) , mFreeBlockIterator(std::nullopt) , mIsFull{false} , mPriority{executor::KvCacheRetentionConfig::kDefaultRetentionPriority} , mDurationMs{std::nullopt} , mExpirationTime{std::nullopt} , mHash{0} { } void KVCacheBlock::startScheduling() { mSchedulingRefCount = mRefCount; } KVCacheBlock::IdType KVCacheBlock::getBlockId() const { return mBlockId; } NextBlockMap KVCacheBlock::getNextBlocks() const { return mNextBlocks; } tk::KVCacheIndex::UnderlyingType KVCacheBlock::getMemoryPoolBlockIndex() const { return mMemoryPoolBlockIndex.get(); } std::vector KVCacheBlock::getExtraKeys() const { return mBlockKey.extraKeys; } bool KVCacheBlock::isPrimary() const { return mMemoryPoolBlockIndex.isPrimary(); } void KVCacheBlock::swapMemoryPoolBlockOffset(std::shared_ptr otherBlock) { std::swap(mMemoryPoolBlockIndex, otherBlock->mMemoryPoolBlockIndex); } void KVCacheBlock::incRefCount() { mRefCount++; } void KVCacheBlock::decRefCount() { TLLM_CHECK_WITH_INFO( hasRefs(), "Can't remove link from block (id=%d) that is not allocated", static_cast(mBlockId)); mRefCount--; } void KVCacheBlock::decSchedulingRefCount() { TLLM_CHECK_WITH_INFO(hasSchedulingRefs(), "Can't remove link from block that is not allocated"); mSchedulingRefCount--; } bool KVCacheBlock::hasRefs() const { return mRefCount > 0; } bool KVCacheBlock::isShared() const { // block is considered shared if ready for reuse return mRefCount > 1 || mPrevBlock != nullptr; } bool KVCacheBlock::hasSchedulingRefs() const { return mSchedulingRefCount > 0; } void KVCacheBlock::setBlockKey(BlockKey const& blockKey, bool isFull) { mBlockKey = blockKey; mIsFull = isFull; } BlockKey KVCacheBlock::getBlockKey() { return mBlockKey; } void KVCacheBlock::setPriority(executor::RetentionPriority priority) { mPriority = priority; } executor::RetentionPriority KVCacheBlock::getPriority() const { return mPriority; } std::optional KVCacheBlock::getDurationMs() const { return mDurationMs; } void KVCacheBlock::setDurationMs(std::optional durationMs) { mDurationMs = durationMs; } void KVCacheBlock::setExpirationTime(std::optional expirationTime) { mExpirationTime = expirationTime; } std::optional KVCacheBlock::getExpirationTime() const { return mExpirationTime; } void KVCacheBlock::setHash(size_t hash) { mHash = hash; } void KVCacheBlock::setHash() { mHash = BlockKeyHasher()(mBlockKey, mPrevBlockInSeq ? 
mPrevBlockInSeq->getHash() : 0); } size_t KVCacheBlock::getHash() const { return mHash; } VecUniqueTokens const& KVCacheBlock::getUniqueTokens() const { return mBlockKey.uniqueTokens; } BlockPtr const& KVCacheBlock::getPrevBlock() const { return mPrevBlock; } void KVCacheBlock::setPrevBlock(BlockPtr prevBlock) { mPrevBlock = std::move(prevBlock); } BlockPtr const& KVCacheBlock::getPrevBlockInSeq() const { return mPrevBlockInSeq; } void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock) { mPrevBlockInSeq = std::move(prevBlock); } void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block) { if (mNextBlocks.find(blockKey) == mNextBlocks.end()) { mNextBlocks[blockKey] = std::move(block); } } std::tuple KVCacheBlock::findMatchingBlock( BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const { if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0) { return {false, 0, nullptr}; } auto itr = mNextBlocks.find(blockKey); if (itr == mNextBlocks.end()) { if (enablePartialReuse) { SizeType32 bestNumMatched{0}; BlockPtr bestBlock{nullptr}; for (auto const& [key, block] : mNextBlocks) { if (copyOnPartialReuse || (!block->hasRefs() && block->isLeaf())) { SizeType32 numMatched = key.partialMatch(blockKey); if (numMatched > bestNumMatched) { bestNumMatched = numMatched; bestBlock = block; } } } if (bestNumMatched > 0) { return {true, bestNumMatched, bestBlock}; } } return {false, 0, nullptr}; } auto block = itr->second; return {!block->isFull(), static_cast(blockKey.uniqueTokens.size()), block}; } void KVCacheBlock::freeLeafBlock() { // assure that this is a leaf block TLLM_CHECK(isLeaf()); // free from previous block if (mPrevBlock != nullptr) { mPrevBlock->removeNextBlock(mBlockKey); mPrevBlock = nullptr; } } void KVCacheBlock::removeNextBlock(BlockKey const& blockKey) { mNextBlocks.erase(blockKey); } bool KVCacheBlock::isFull() const { return mIsFull; } bool KVCacheBlock::isLeaf() const { return mNextBlocks.empty(); } // This function calculates the number of block a layer should have, given // the total free memory and the window size of each layer. // For example, if we have 1 layer of window size 1024, and 2 layer of window // size 2048, and 3 layers of 4096. // Each layer of window size 1024 should have // 1024 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks. // Each layer of window size 2048 should have // 2048 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks. // Each layer of window size 4096 should have // 4096 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks. // NOTE: Currently the use of this function is not used for // BaseKVCacheManager::calculateMaxNumBlocks because the we want to first // achieve identical performance as assuming all layers as full attention. 
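// Worked example for the sharing rule above (illustrative numbers only, assuming the
// per-window cacheSizePerToken is proportional to the number of layers in that group,
// i.e. 1x, 2x and 3x of a common per-layer size):
//
//   denominator = 1*1024 + 2*2048 + 3*4096 = 17408
//   share(1024) = 1*1024 / 17408 ≈ 0.059
//   share(2048) = 2*2048 / 17408 ≈ 0.235
//   share(4096) = 3*4096 / 17408 ≈ 0.706   (shares sum to 1.0)
//
// Under that assumption, calculateWindowSizeToShare below yields these per-window
// fractions, which the caller multiplies by the total block budget.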
std::map BlockManager::calculateWindowSizeToShare( std::map> const& windowSizeToLayers, std::map const& windowSizeToCacheSizePerToken) { if (windowSizeToLayers.size() == 1) { return {{windowSizeToLayers.begin()->first, 1.0f}}; } std::map windowSizeToContribution; SizeType32 cacheSizePerTokenTotal = std::accumulate(windowSizeToCacheSizePerToken.begin(), windowSizeToCacheSizePerToken.end(), SizeType32{0}, [](auto sum, auto const& windowSize) { return sum + windowSize.second; }); for (auto const& [windowSize, cacheSizePerToken] : windowSizeToCacheSizePerToken) { auto const cacheSizeWeight = static_cast(cacheSizePerToken) / cacheSizePerTokenTotal; windowSizeToContribution[windowSize] = cacheSizeWeight; } for (auto const& [windowSize, _] : windowSizeToLayers) { windowSizeToContribution.at(windowSize) *= windowSize; } auto const windowSizesTotalSum = std::accumulate(windowSizeToContribution.begin(), windowSizeToContribution.end(), 0.0, [](auto sum, auto const& windowSize) { return sum + windowSize.second; }); std::map windowSizeToShare; for (auto const& [windowSize, windowSizeSum] : windowSizeToContribution) { float const fraction = windowSizeSum / windowSizesTotalSum; TLLM_CHECK(0.0f < fraction && fraction <= 1.0f); windowSizeToShare[windowSize] = fraction; } auto total = std::accumulate(windowSizeToShare.begin(), windowSizeToShare.end(), 0.0f, [](auto sum, auto const& windowSize) { return sum + windowSize.second; }); TLLM_CHECK(total == 1.0f); return windowSizeToShare; } BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, std::shared_ptr stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager, std::optional agentConfig, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim) : mNumLayers{static_cast(numKvHeadsPerLayer.size())} , mTokensPerBlock{tokensPerBlock} , mEventManager{std::move(eventManager)} , mStream{stream} , mCacheType{cacheType} , mIsEnableIndexerKCache{enableIndexerKCache} , mIndexerKCacheQuantBlockSize{indexerKCacheQuantBlockSize} , mIndexerKCacheIndexHeadDim{indexerKCacheIndexHeadDim} { if (agentConfig.has_value()) mLoopbackAgent = makeLoopbackAgent("nixl", &agentConfig.value()); else mLoopbackAgent = nullptr; auto const uniqueWindowSizeToLayers = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers); TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1, "KV Cache Connector is not supported with multiple window sizes"); auto const numUniqueWindowSizes = static_cast(uniqueWindowSizeToLayers.size()); mIsVariableWindow = numUniqueWindowSizes > 1; mIsVariableGQA = std::unordered_set(numKvHeadsPerLayer.begin(), numKvHeadsPerLayer.end()).size() > 1; mLayerToWindowSize.resize(mNumLayers); for (auto const& [windowSize, layersWithWindowSize] : uniqueWindowSizeToLayers) { if (windowSize > maxSequenceLength) { TLLM_LOG_WARNING("[kv cache manager] window size %d is greater than max sequence length %d", windowSize, maxSequenceLength); } for (auto& layerIdx : layersWithWindowSize) { 
mLayerToWindowSize.at(layerIdx) = windowSize; } auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize); TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks... mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, /*isSWA=*/windowSize < maxSequenceLength, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, mLoopbackAgent, enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim); } auto const numAllPools = getNumPools(); mAbsolutePoolToWindowSize.reserve(numAllPools); mAbsolutePoolToRelativePoolIndex.reserve(numAllPools); auto absolutePoolsOffset = SizeType32{0}; for (auto const& [windowSize, manager] : mWindowBlockManagers) { auto const numPools = manager.getNumPools(); for (auto i = 0; i < numPools; ++i) { mAbsolutePoolToWindowSize.push_back(windowSize); mAbsolutePoolToRelativePoolIndex.push_back(i); } // SWA allocates blocks linearly, and we need as many blocks as full attention, // where full attention has windowSize = maxSequenceLength. auto const maxTokenNum = std::max(windowSize, maxSequenceLength) + sinkBubbleLength; auto const temporaryAttentionWindow = manager.calculateTemporaryAttentionWindow(tempAttentionWindowInputs); // Consider the temporaryAttentionWindow when allocating blocks. // Current tempAttentionWindow calculation does not consider the // concept of SWA right now at most occupying maxSequenceLength of // blocks. So the calculation of maxToken + tempAttention will exceed // maxSequenceLength. A temporary resolution here is to cap the // calculation to maxSequenceLength. I will proceed with a follow-up // MR to remove the tempAttentionWindow concept. auto const maxBlocksPerSeq = tc::ceilDiv(std::min(maxSequenceLength, maxTokenNum + temporaryAttentionWindow), tokensPerBlock); auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize); mWindowSizeToMetadata[windowSize] = WindowSizeMetadata{allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, manager.getMaxNumBlocks(), temporaryAttentionWindow, windowSize, manager.isSWA()}; TLLM_LOG_INFO( "Max KV cache blocks per sequence: %d [window size=%d], tokens per block=%d, primary blocks=%d, secondary " "blocks=%d, max sequence length=%d", maxBlocksPerSeq, windowSize, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxSequenceLength); TLLM_LOG_DEBUG( "%s Metadata: %s", manager.getLogPrefix().c_str(), mWindowSizeToMetadata[windowSize].toString().c_str()); absolutePoolsOffset += numPools; } TLLM_CHECK_WITH_INFO(mWindowBlockManagers.size() == mWindowSizeToMetadata.size() && std::equal(mWindowBlockManagers.cbegin(), mWindowBlockManagers.cend(), mWindowSizeToMetadata.cbegin(), mWindowSizeToMetadata.cend(), [](auto const& window1, auto const& window2) { return window1.first == window2.first; }), "Iteration order of window sizes between mWindowBlockManagers and mWindowSizeToMetadata *must* be ensured. 
" "Maybe you tried changing either of them to an std::unordered_map?"); } WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize, std::vector const& managedLayers, std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager, std::shared_ptr loopbackAgent, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim) : mDataType{dtype} , mWindowSize{windowSize} , mNumPrimaryBlocks{blocksInPrimaryPool} , mNumSecondaryBlocks{blocksInSecondaryPool} , mOnboardBlocks(onboardBlocks) , mBufferManager{std::move(stream)} , mSchedulingNumFreeBlocks{0} , mTokensPerBlock{tokensPerBlock} , mIsSWA{isSWA} , mCachedBlocksRoot{std::make_shared(KVCacheBlock::kCachedBlocksRootId, tk::KVCacheIndex{0})} , mCacheType{cacheType} , mEventManager(std::move(eventManager)) , mLoopbackAgent{loopbackAgent} , mTransferManager{std::make_shared(mBufferManager, mLoopbackAgent)} , mAllocTotalBlocks{0} , mAllocNewBlocks{0} , mReusedBlocks{0} , mReusedUniqueBlocks{0} , mMissedBlocks{0} , mKVFactor{mCacheType == CacheType::kSELFKONLY ? 1 : 2} , mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)} , mReusedTokens{0.0} , mTotalInputTokens{0.0} , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)} , mEnableIndexerKCache{enableIndexerKCache} , mIndexerKCacheQuantBlockSize{indexerKCacheQuantBlockSize} , mIndexerKCacheIndexHeadDim{indexerKCacheIndexHeadDim} { std::map numLayersPerPool; for (auto const layerIdx : managedLayers) { auto const& layerIndexWithinPool = numLayersPerPool[numKvHeadsPerLayer.at(layerIdx)]++; mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool; } auto numEltsPerContainer = getNumEltsPerContainer(); #ifdef ENABLE_FP4 if (numEltsPerContainer == 2) { TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache."); } #endif size_t poolIndex = 0; for (auto const [numKvHeads, numLayers] : numLayersPerPool) { for (auto const layerIdx : managedLayers) { if (numKvHeadsPerLayer.at(layerIdx) == numKvHeads) { mLayerToPoolIndex[layerIdx] = poolIndex; } } mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock); ++poolIndex; } #ifdef ENABLE_FP4 // TODO(miovine): make the block size configurable. Should we have an additional argument // to specify FP4 related parameters (scale dtypes, etc)? This can also be passed // in the constructor. 
constexpr SizeType32 kQuantBlockSizeNVFP4 = 16; if (dtype == nvinfer1::DataType::kFP4) { createBlockScalePools(kQuantBlockSizeNVFP4); } #endif if (mEnableIndexerKCache) { createIndexerKCachePools(); } // Create free blocks mAllBlocksById.reserve(blocksInPrimaryPool + blocksInSecondaryPool); for (KVCacheBlock::IdType blockId = 0; blockId < blocksInPrimaryPool; ++blockId) { mAllBlocksById.emplace_back(std::make_shared(blockId, tk::KVCacheIndex{blockId, false})); } for (KVCacheBlock::IdType blockId = 0; blockId < blocksInSecondaryPool; ++blockId) { mAllBlocksById.emplace_back( std::make_shared(blocksInPrimaryPool + blockId, tk::KVCacheIndex{blockId, true})); } mAllocatedBlocksPerSeq.reserve(maxNumSequences); mEvictionPolicy = std::make_shared(); mEvictionPolicy->initialize( mAllBlocksById, {blocksInPrimaryPool, blocksInSecondaryPool}, secondaryOffloadMinPriority); if (mEventManager) { mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool}, mWindowSize); } } WindowBlockManager::~WindowBlockManager() { float reusedUniqueBlocksPercentage = mReusedUniqueBlocks == 0 || mAllocTotalBlocks == 0 ? 0 : static_cast(mReusedUniqueBlocks) / static_cast(mAllocNewBlocks) * 100; float cacheHitRate = mReusedBlocks == 0 ? 0 : static_cast(mReusedBlocks) / (static_cast(mReusedBlocks + mMissedBlocks)); TLLM_LOG_DEBUG("%s - total allocated blocks: %lu ", mLogPrefix.c_str(), mAllocTotalBlocks); TLLM_LOG_DEBUG("%s - allocated new blocks: %lu ", mLogPrefix.c_str(), mAllocNewBlocks); TLLM_LOG_DEBUG("%s - missed blocks: %lu ", mLogPrefix.c_str(), mMissedBlocks); TLLM_LOG_DEBUG("%s - reused blocks: %lu ", mLogPrefix.c_str(), mReusedBlocks); TLLM_LOG_DEBUG("%s - reused unique blocks: %lu ", mLogPrefix.c_str(), mReusedUniqueBlocks); TLLM_LOG_DEBUG( "%s - reused unique blocks percentage (%%): %.2f ", mLogPrefix.c_str(), reusedUniqueBlocksPercentage); TLLM_LOG_DEBUG("%s - cache hit rate: %.2f ", mLogPrefix.c_str(), cacheHitRate); TLLM_LOG_DEBUG("%s - reused tokens: %.0f ", mLogPrefix.c_str(), mReusedTokens); TLLM_LOG_DEBUG("%s - reused tokens percentage (%%): %.2f ", mLogPrefix.c_str(), 100.0 * mReusedTokens / mTotalInputTokens); } bool BlockManager::verifyQueueIntegrity(SizeType32 windowSize) { return mWindowBlockManagers.at(windowSize).verifyQueueIntegrity(); } bool WindowBlockManager::verifyQueueIntegrity() { return mEvictionPolicy->verifyQueueIntegrity(); } void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest) { constexpr int beamIdx = 0; // no need to consider more than one beam for input tokens for (auto const& [windowSize, _] : mWindowBlockManagers) { if (mWindowBlockManagers.at(windowSize).isSWA()) { // SWA cannot store new blocks on the fly because the block stored // may go OOW and be reused by another sequence. 
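// Illustrative scenario (numbers assumed): with windowSize = 1024 and tokensPerBlock = 64,
// a sequence that grows past 1024 + 64 tokens has its oldest block fall outside the
// attention window (see minTokensForBlockDetach in adjustBlocksIfNeeded). Storing such a
// block for reuse here could publish contents that another sequence is about to
// overwrite, hence the skip below.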
continue; } auto cacheBlockIds = sequence.getCacheBlockIds(windowSize); auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx); auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false); auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); (void) mWindowBlockManagers.at(windowSize).storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); } } void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize) { SizeType32 const numEltsPerContainer = getNumEltsPerContainer(); SizeType32 numPools = mPools.size(); for (SizeType32 i = 0; i < numPools; ++i) { auto& kvPool = mPools[i]; if (kvPool.containsIndexerKCache || kvPool.containsBlockScales) { continue; } TLLM_CHECK_WITH_INFO((kvPool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0, "Cannot use FP4 quantization since kvPool.sizePerHead is not divisible by FP4 quantBlockSize."); auto blockScaleSizePerHead = kvPool.sizePerHead * numEltsPerContainer / quantBlockSize; mPools.emplace_back(kvPool.numLayers, kvPool.kvFactor, kvPool.numKvHeads, blockScaleSizePerHead, kvPool.tokensPerBlock, /*primaryPool=*/nullptr, /*secondaryPool=*/nullptr, /*containsBlockScales=*/true, /*containsIndexerKCache=*/false); } } void WindowBlockManager::createIndexerKCachePools() { SizeType32 numPools = mPools.size(); for (SizeType32 i = 0; i < numPools; ++i) { auto& kvPool = mPools[i]; if (kvPool.containsIndexerKCache || kvPool.containsBlockScales) { continue; } SizeType32 scaleSize = mIndexerKCacheIndexHeadDim / mIndexerKCacheQuantBlockSize * 4; mPools.emplace_back(kvPool.numLayers, kvPool.kvFactor, 1, scaleSize + mIndexerKCacheIndexHeadDim, kvPool.tokensPerBlock, /*primaryPool=*/nullptr, /*secondaryPool=*/nullptr, /*containsBlockScales=*/false, /*containsIndexerKCache=*/true); } } void BlockManager::allocatePools(bool useUvm) { for (auto& [_, manager] : mWindowBlockManagers) { manager.allocatePools(useUvm); } } void WindowBlockManager::allocatePools(bool useUvm) { constexpr nvinfer1::DataType kScaleDtypeNVFP4 = nvinfer1::DataType::kFP8; // Allocate a memory pool backing the blocks for each numKvHeads // TODO(oargov): allocate pools in a single buffer and split it, to avoid fragmentation for (auto& pool : mPools) { auto blockSize = pool.blockSize; auto poolDtype = pool.containsBlockScales ? 
kScaleDtypeNVFP4 : mDataType; #ifdef ENABLE_FP4 auto const poolIsFP4 = poolDtype == nvinfer1::DataType::kFP4; #else auto const poolIsFP4 = false; #endif if (poolIsFP4) { poolDtype = nvinfer1::DataType::kINT8; } if (pool.containsIndexerKCache) { poolDtype = nvinfer1::DataType::kUINT8; } nvinfer1::Dims cacheShape; cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize}); TLLM_LOG_DEBUG("[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(), mNumPrimaryBlocks, pool.numLayers, pool.numKvHeads); if (useUvm) pool.primaryPtr = BufferManager::managed(cacheShape, poolDtype); else pool.primaryPtr = mBufferManager.gpuSync(cacheShape, poolDtype); if (mNumSecondaryBlocks > 0) { nvinfer1::Dims const cacheShapeOffload = ITensor::makeShape({mNumSecondaryBlocks, pool.numLayers, mKVFactor, blockSize}); TLLM_LOG_DEBUG("[%s] Allocating secondary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(), mNumSecondaryBlocks, pool.numLayers, pool.numKvHeads); pool.secondaryPtr = BufferManager::pinned(cacheShapeOffload, poolDtype); } } } void BlockManager::releasePools() { for (auto& [_, manager] : mWindowBlockManagers) { manager.releasePools(); } } void WindowBlockManager::releasePools() { for (auto& pool : mPools) { if (pool.primaryPtr) { pool.primaryPtr->release(); } if (pool.secondaryPtr) { pool.secondaryPtr->release(); } } mBufferManager.getStream().synchronize(); mBufferManager.memoryPoolTrimTo(0); } void BlockManager::startScheduling() { for (auto& [_, manager] : mWindowBlockManagers) { manager.startScheduling(); } } void WindowBlockManager::startScheduling() { mSchedulingNumFreeBlocks = mEvictionPolicy->getNumFreeBlocks(kPrimaryLevel); for (auto& [requestId, slotAllocatedBlocks] : mAllocatedBlocksPerSeq) { for (auto& allocatedBlock : slotAllocatedBlocks) { allocatedBlock->startScheduling(); } } } void WindowBlockManager::freeLeafBlock(BlockPtr const& block) { // The eviction policy needs blocks to still be linked to their old parents when they're reclaimed. // This is so it can check if the parent should be queued for eviction. block->freeLeafBlock(); } void WindowBlockManager::freeChildren(BlockPtr const& block) { // Free all descendants of block for (auto const& p : block->getNextBlocks()) { auto childBlock = p.second; freeChildren(childBlock); } // Free block if (mEventManager && blockInRadixTree(block)) { mEventManager->enqueueRemovedEvent(block, mWindowSize); } freeLeafBlock(block); } BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority, std::optional durationMs, executor::KvCacheTransferMode mode, std::string const& directory) { // eviction policy get free primary block auto [block, canOffload] = mEvictionPolicy->getFreeBlock(kPrimaryLevel); if (block->getUniqueTokens().empty()) { ++mAllocNewBlocks; } ++mAllocTotalBlocks; // Offloading is an option only when these conditions are met: // 1. Block contains state (evidenced by presence of tokens) // 2. Eviction policy indicated block can be offloaded // 3. At least one free block in secondary memory // 4. 
Onboarding is enabled (allowing block to be brought back into primary) if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0 && mOnboardBlocks) { // Offload block in primary memory before repurposing auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel)); mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory); // swap linear block offsets (i.e. make block the offload block) block->swapMemoryPoolBlockOffset(offloadBlock); if (mEventManager && blockInRadixTree(block)) { mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel), mWindowSize); } // Update the block as a secondary block (maintaining its priority) mEvictionPolicy->claimBlock(block); // Release the block into secondary block queue mEvictionPolicy->releaseBlock(block); // We have the offloaded block as the block to use now. block = offloadBlock; } // Removes children of the block from the search tree freeChildren(block); // Claim the block in primary block queue mEvictionPolicy->claimBlock(block, priority, durationMs); // Deal with invalidating block save for reuse for the sequence if (mBlockToSequence.count(block->getBlockId()) > 0) { auto const& originalOwnerSequenceId = mBlockToSequence[block->getBlockId()]; if (mIsValidStoreForReuseSequence.count(originalOwnerSequenceId) > 0 && sequence.getRequestId() != originalOwnerSequenceId) { TLLM_LOG_DEBUG("%s::getFreeBlock - Block %d was originally held but released from sequence %d", mLogPrefix.c_str(), block->getBlockId(), originalOwnerSequenceId); if (mIsValidStoreForReuseSequence[originalOwnerSequenceId]) { TLLM_LOG_DEBUG("%s::getFreeBlock - Invalidate store block for reuse for sequence %d", mLogPrefix.c_str(), originalOwnerSequenceId); } else { TLLM_LOG_DEBUG("%s::getFreeBlock - Store block for reuse for sequence %d is already invalid", mLogPrefix.c_str(), originalOwnerSequenceId); } mIsValidStoreForReuseSequence[originalOwnerSequenceId] = false; } } // Record which sequence is using the block mBlockToSequence[block->getBlockId()] = sequence.getRequestId(); TLLM_LOG_DEBUG("%s::getFreeBlock - Block %d is now acquired by sequence %d", mLogPrefix.c_str(), block->getBlockId(), sequence.getRequestId()); return block; } void WindowBlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId) const { auto constexpr kIdx = 0; auto constexpr vIdx = 1; auto const& block = mAllBlocksById[blockId]; for (SizeType32 poolIdx = 0; poolIdx < static_cast(mPools.size()); poolIdx++) { auto const& pool = mPools.at(poolIdx); for (auto const xIdx : {kIdx, vIdx}) { auto constexpr layerIdx = 0; auto const offsetIndex = tensorrt_llm::common::flat_index(offsetsShape.d, poolIdx, beamIdx, xIdx, blockIdx); auto const fieldIdx = mCacheType == CacheType::kSELFKONLY ? 
0 : xIdx; auto const blockIndex = tk::KVCacheIndex{ common::flat_index3(block->getMemoryPoolBlockIndex(), layerIdx, fieldIdx, pool.numLayers, mKVFactor)}; offsetsPtr[offsetIndex] = blockIndex; } } } void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId, SizeType32 windowSize) const { mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } void BlockManager::onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock, SizeType32 windowSize, executor::KvCacheTransferMode mode, std::string const& directory) { mWindowBlockManagers.at(windowSize).onboardBlock(sequence, offloadBlock, mode, directory); } void WindowBlockManager::onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock, executor::KvCacheTransferMode mode, std::string const& directory) { if (mOnboardBlocks && !offloadBlock->isPrimary()) { auto block = getFreeBlock( sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, mode, directory); mTransferManager->onboard(offloadBlock, block, mPools, 0, mode, directory); // swap linear block offsets (i.e. make block the offload block and vice versa) offloadBlock->swapMemoryPoolBlockOffset(block); if (mEventManager) { mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel), mWindowSize); } mEvictionPolicy->releaseBlock(block); // append block to offload queue // offloadBlock is now in primary memory pool } } void BlockManager::offloadBlock( BlockPtr const& block, SizeType32 windowSize, executor::KvCacheTransferMode mode, std::string const& directory) { mWindowBlockManagers.at(windowSize).offloadBlock(block, mode, directory); } void WindowBlockManager::offloadBlock( BlockPtr const& block, executor::KvCacheTransferMode mode, std::string const& directory) { // The current default behavior is to offload the out-of-window block // to secondary block pool to allow more free primary blocks for reuse. // However, such behavior does not take account whether the offloaded // block is useful or not and may just lead to more traffic instead. // The ideal way of this is to dedicate the offloading of the block // to the eviction policy. if (mOnboardBlocks && block->isPrimary()) { // Offload block in primary memory before repurposing auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel)); // If we're swapping a block to secondary memory, maintain the prior priority values. mEvictionPolicy->claimBlock(offloadBlock); mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory); // swap linear block offsets (i.e. 
make block the offload block) block->swapMemoryPoolBlockOffset(offloadBlock); if (mEventManager && blockInRadixTree(block)) { mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel), mWindowSize); } mEvictionPolicy->releaseBlock(offloadBlock); // append offloadBlock to mFreePrimaryBlocks queue // block is now in secondary memory } } [[nodiscard]] std::optional BlockManager::findNewContextBlock( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { TLLM_CHECK_WITH_INFO( !isVariableWindow(), "The optimization of delaying requests won't work for variable window attention"); auto const& onlyManager = mWindowBlockManagers.cbegin()->second; return onlyManager.findNewContextBlock(uniqueTokens, llmRequest); } std::optional WindowBlockManager::findNewContextBlock( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size(), mTokensPerBlock, false); auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); BlockKey ret; ret.loraTaskId = llmRequest.getLoraTaskId(); auto searchRoot = mCachedBlocksRoot; for (auto const& blockKey : blockKeys) { ret.uniqueTokens.insert(ret.uniqueTokens.end(), blockKey.uniqueTokens.begin(), blockKey.uniqueTokens.end()); auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr); if (matchingBlock == nullptr) { return ret; } searchRoot = std::move(matchingBlock); } return std::nullopt; } bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) { return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr; } std::shared_ptr WindowBlockManager::findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey) { std::lock_guard lock(mCachedBlocksRootMutex); auto blockedUniqueTokens = chopVectorIntoBlocks(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true); std::vector blockKeys; for (auto const& blockedUniqueTokensList : blockedUniqueTokens) { blockKeys.push_back(blockKey); blockKeys.back().uniqueTokens = blockedUniqueTokensList; } auto searchRoot = mCachedBlocksRoot; for (auto const& blockKey : blockKeys) { auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr ? searchRoot->findMatchingBlock(blockKey, true, true) : std::make_tuple(false, 0, nullptr); if (matchingBlock == nullptr) { return nullptr; } searchRoot = std::move(matchingBlock); } return searchRoot; } SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& blockKeys, SizeType32 numContextBlocks, GenerationRequest& sequence, std::vector const& perBlockRetentions, executor::KvCacheTransferMode mode, std::string const& directory) { std::lock_guard lock(mCachedBlocksRootMutex); SizeType32 numMatchedTokens{0}; auto searchRoot = mCachedBlocksRoot; // The last block cannot be shared between beams because it will be written to. // Make sure a unique block is allocated per beam. auto const beamWidth = sequence.getBeamWidth(); SizeType32 numSharedContextBlocks = beamWidth > 1 ? numContextBlocks - 1 : numContextBlocks; auto blockItr = blockKeys.begin(); for (int bi = 0; bi < numSharedContextBlocks; ++bi) { auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end() ? 
searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse) : std::make_tuple(false, 0, nullptr); if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen()) { KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId(); numMatchedTokens += numMatched > 0 ? numMatched : blockItr->uniqueTokens.size(); if (perBlockRetentions[bi].retentionPriority.has_value() && matchingBlock->getPriority() != perBlockRetentions[bi].retentionPriority && mEventManager) { mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(matchingBlock->getHash()) .priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority), mWindowSize); } if (partialMatch) { if (matchingBlock->hasRefs() || !matchingBlock->isLeaf()) { // Somebody else is using block or it is not a leaf, copy reusable tokens auto newBlock = getFreeBlock( sequence, matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory); mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory); // TODO: (optional) Send out event matchingBlock = newBlock; if (blockItr != blockKeys.end()) { matchingBlock->setBlockKey( *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); } matchingBlock->setHash(); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Copied partially filled block %d", mLogPrefix.c_str(), matchingBlockId); } else { // Leaf block that nobody is using. Make block private and reuse freeLeafBlock(matchingBlock); mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(), matchingBlockId); } searchRoot = nullptr; // no matching needed for following blocks } else { // Recover block and reuse mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId); searchRoot = matchingBlock; } onboardBlock(sequence, matchingBlock, mode, directory); addBlockToAllBeams(matchingBlock, sequence); // TODO: only add once for reused blocks ++mReusedBlocks; if (!reusedBlockIds.count(matchingBlockId)) { reusedBlockIds.insert(matchingBlockId); ++mReusedUniqueBlocks; } ++blockItr; } else { // If we haven't set a priority, set it to the default priority level (low) auto freeBlock = getFreeBlock(sequence, perBlockRetentions[bi].retentionPriority.value_or( executor::KvCacheRetentionConfig::kDefaultRetentionPriority), perBlockRetentions[bi].durationMs, mode, directory); addBlockToAllBeams(freeBlock, sequence); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - No match, allocated new block %d for sequence %lu", mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId()); searchRoot = nullptr; // no matching needed for following blocks if (blockItr != blockKeys.end()) { freeBlock->setBlockKey( *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); ++blockItr; } freeBlock->setHash(); ++mMissedBlocks; } } // Allocate new blocks that cannot be shared by multiple beams. for (int bi = numSharedContextBlocks; bi < numContextBlocks; ++bi) { // TODO: Still look for match. Clone matching block or allocate fresh ones. // This work is described in JIRA task https://jirasw.nvidia.com/browse/TRTLLM-2069. 
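// Example of the split handled here (beam width and block count assumed): with
// beamWidth = 2 and numContextBlocks = 4, numSharedContextBlocks = 3, so blocks 0..2 were
// matched or allocated once and shared across beams above, while this loop handles only
// block 3 and allocates one private block per beam (two fresh blocks), because the last
// block is written to independently by each beam.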
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { // If we haven't set a priority, set it to the default priority level (low) auto freeBlock = getFreeBlock(sequence, perBlockRetentions[bi].retentionPriority.value_or( executor::KvCacheRetentionConfig::kDefaultRetentionPriority), perBlockRetentions[bi].durationMs, mode, directory); addBlockToBeam(freeBlock, sequence, beamIdx); if (blockItr != blockKeys.end()) { freeBlock->setBlockKey( *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); ++blockItr; } freeBlock->setHash(); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d", mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi); } ++mMissedBlocks; if (blockItr != blockKeys.end()) { ++blockItr; } } sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens); return numMatchedTokens; } void BlockManager::syncTransferManagerWithBufferManager() { for (auto& [_, manager] : mWindowBlockManagers) { manager.syncTransferManagerWithBufferManager(); } } void WindowBlockManager::syncTransferManagerWithBufferManager() { mTransferManager->syncWithBufferManager(); } void BlockManager::refreshBlocks() { for (auto& [_, manager] : mWindowBlockManagers) { manager.refreshBlocks(); } } void WindowBlockManager::refreshBlocks() { mEvictionPolicy->refresh(); mTransferManager->syncTransfers(); } // There are two versions of BlockManager::addSequence function. // This is called when block reuse is enabled. void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); } // There are two versions of WindowBlockManager::addSequence function. // This is called when block reuse is enabled. void WindowBlockManager::addSequence( GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest) { auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); TLLM_CHECK(emplaceDone); auto constexpr beamIdx = 0; auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY) ? 
llmRequest.getUniqueTokens(beamIdx) : *(llmRequest.getEncoderUniqueTokens().value()); // Ignore last token because it can't be recovered auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, inputLength - 1, mTokensPerBlock, true); // Add empty block if last token is separated if (inputLength % mTokensPerBlock == 1) { blockedUniqueTokens.emplace_back(); } auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); auto config = llmRequest.getKvCacheRetentionConfig(); auto perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig()) .getPerBlockRetentionPriorityDuration(getTokensPerBlock(), inputLength); auto mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode(); auto directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory(); if (mode != executor::KvCacheTransferMode::DRAM && directory.empty()) { TLLM_LOG_WARNING( "Transfer mode %d specified without directory, falling back to DRAM mode", static_cast(mode)); mode = executor::KvCacheTransferMode::DRAM; } TLLM_CHECK(perBlockRetentions.size() == (size_t) numContextBlocks); auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions, mode, directory); mReusedTokens += static_cast(prepopulatedPromptLen); mTotalInputTokens += static_cast(uniqueTokens.size()); SizeType32 numConnectorMatchedTokens = 0; // If we're using a KV cache connector, check if any additional blocks can be loaded. if (mKvCacheConnectorManager && !llmRequest.isDummyRequest()) { numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); } llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); } void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence) { for (auto& [windowSize, manager] : mWindowBlockManagers) { mWindowBlockManagers.at(windowSize).adjustBlocksIfNeeded(sequence); } } void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence) { auto const minTokensForBlockDetach = mWindowSize + mTokensPerBlock; while ( sequence.getNumTokens() - sequence.getNumFrontBlocksRemoved() * getTokensPerBlock() >= minTokensForBlockDetach) { // Detaching block for SWA is non-trivial due to the radix tree structure. // For now, when reuse is enabled, we do not detach blocks for SWA. TLLM_CHECK_WITH_INFO(mIsSWA, "A block only go out-of-window in SWA"); detachFrontBlock(sequence); } if ((sequence.getNumTokens() - 1) % getTokensPerBlock() == 0) { // Allocating a new block when the last token is a block boundary allocateBlock(sequence, /*shareAmongBeams=*/sequence.getBeamWidth() == 1); updateLastCacheBlockOffsets(sequence); } } // There are two versions of BlockManager::addSequence function. // This is called when block reuse is disabled. void BlockManager::addSequence( GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock) { mWindowBlockManagers.at(windowSize).addSequence(sequence, numContextBlocks, isShareLastContextBlock); } // There are two versions of WindowBlockManager::addSequence function. // This is called when block reuse is disabled. 
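// Allocation sketch for the reuse-disabled path below (illustrative numbers): with
// numContextBlocks = 3 and beamWidth = 2, the loop allocates blocks 0 and 1 once each,
// shared across both beams, and the final allocateBlock call allocates block 2 either
// shared as well (isShareLastContextBlock == true) or once per beam otherwise.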
void WindowBlockManager::addSequence( GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock) { if (mKvCacheConnectorManager) { TLLM_LOG_WARNING( "KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be " "ignored."); } auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); TLLM_CHECK(emplaceDone); TLLM_CHECK_WITH_INFO(numContextBlocks > 0, "numContextBlocks must be greater than 0"); for (SizeType32 bi = 0; bi < numContextBlocks - 1; ++bi) { allocateBlock(sequence, /*shareAmongBeams=*/true); } allocateBlock(sequence, /*shareAmongBeams=*/isShareLastContextBlock); } void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx) { auto const requestId = sequence.getRequestId(); block->incRefCount(); if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0) { block->setPrevBlockInSeq(nullptr); } else { block->setPrevBlockInSeq(mAllBlocksById.at(sequence.getCacheBlockIds(mWindowSize)[beamIdx].back())); } sequence.addCacheBlock(mWindowSize, beamIdx, block->getBlockId()); mAllocatedBlocksPerSeq.at(requestId).push_back(block); } void WindowBlockManager::addBlockToAllBeams(BlockPtr& block, GenerationRequest& sequence) { auto const beamWidth = sequence.getBeamWidth(); for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { addBlockToBeam(block, sequence, beamIdx); } } void BlockManager::allocateBlock(GenerationRequest& sequence, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).allocateBlock(sequence, false); } void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAmongBeams) { auto const beamWidth = sequence.getBeamWidth(); auto const requiredBlocks = shareAmongBeams ? 1 : beamWidth; TLLM_CHECK_WITH_INFO(hasFreeBlocks(requiredBlocks), "Can't allocate new blocks. No free blocks left."); if (shareAmongBeams) { // add same block to all beams auto block = getFreeBlock(sequence, sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(), sequence.getTransferMode(), sequence.getDirectory()); for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { addBlockToBeam(block, sequence, beamIdx); } } else { // add different block to each beam for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto block = getFreeBlock(sequence, sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(), sequence.getTransferMode(), sequence.getDirectory()); addBlockToBeam(block, sequence, beamIdx); } } } std::pair> WindowBlockManager::storeBlocks( std::vector const& blockKeys, std::vector const& blockIds, bool pinBlocks) { SizeType32 numBlocksStoredForReuse = 0; std::lock_guard lock(mCachedBlocksRootMutex); TLLM_LOG_DEBUG( "%s::storeBlocks - %zu blockKeys, %zu blockIds", mLogPrefix.c_str(), blockKeys.size(), blockIds.size()); auto searchRoot = mCachedBlocksRoot; bool needMatch = true; auto numBlocks = blockKeys.size(); std::vector storedBlocks; std::vector pinnedBlockIds; for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) { auto const bid = blockIds[blockCnt]; TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid); auto& block = mAllBlocksById[bid]; auto const& blockKey = blockKeys[blockCnt]; auto [partialMatch, numMatched, matchedBlock] = needMatch ? 
searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr); if (matchedBlock != nullptr) { // Found match TLLM_LOG_DEBUG( "%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId()); searchRoot = matchedBlock; // TODO possible optimization: if bid != matchedBlock->getBlockId(), // block can be freed and inserted at mFreePrimaryBlocks.begin() } else { // No match TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(), block->getBlockId()); TLLM_CHECK_WITH_INFO(block->getBlockId() == bid, "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid)); needMatch = false; // no matching needed for following blocks block->setBlockKey(blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock); block->setPrevBlock(searchRoot); block->setPrevBlockInSeq(searchRoot); searchRoot->addNextBlock(blockKey, block); // Sanity check. The list of stored blocks should be connected. TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back()); storedBlocks.push_back(block); TLLM_CHECK(block->getPrevBlockInSeq() == nullptr || block->getPrevBlockInSeq()->getHash() == searchRoot->getHash()); auto oldHash = block->getHash(); auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash()); if (oldHash != newHash) { TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash); block->setHash(newHash); } searchRoot = block; numBlocksStoredForReuse++; } if (pinBlocks) { searchRoot->incRefCount(); pinnedBlockIds.push_back(searchRoot->getBlockId()); } } if (mEventManager) { mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); } return {numBlocksStoredForReuse, pinnedBlockIds}; } void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) { mWindowBlockManagers.at(windowSize).replaceSharedBlock(sequence, blockIdx); } void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx) { auto const requestId = sequence.getRequestId(); auto const beamWidth = sequence.getBeamWidth(); auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); if (!allocatedBlocks.at((blockIdx + 1) * beamWidth - 1)->isShared()) { return; } BlockKey blockKey = allocatedBlocks.at(blockIdx * beamWidth)->getBlockKey(); bool isFull = allocatedBlocks.at(blockIdx * beamWidth)->isFull(); // Free shared block for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto block = allocatedBlocks.at(blockIdx * beamWidth + beamIdx); block->decRefCount(); if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); } } // Allocate new blocks TLLM_CHECK_WITH_INFO(hasFreeBlocks(beamWidth), "Can't allocate new blocks. 
No free blocks left."); for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto block = getFreeBlock(sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, sequence.getTransferMode(), sequence.getDirectory()); block->incRefCount(); if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0) { block->setPrevBlockInSeq(nullptr); } else { block->setPrevBlockInSeq(mAllBlocksById.at(sequence.getCacheBlockIds(mWindowSize)[beamIdx].back())); } block->setBlockKey(blockKey, isFull); block->setHash(); sequence.changeCacheBlock(mWindowSize, beamIdx, blockIdx, block->getBlockId()); allocatedBlocks.at(blockIdx * beamWidth + beamIdx) = block; } } void BlockManager::releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).releaseLastBlock(sequence); } void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence) { auto const requestId = sequence.getRequestId(); auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); auto it = allocatedBlocks.rbegin(); auto& block = *it; // Decrease ref count block->decRefCount(); // If ref count is zero, move block to free blocks if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block, true); } // Remove block from allocated blocks allocatedBlocks.pop_back(); // Remove stored block ids in sequence sequence.removeLastBlock(mWindowSize); } [[nodiscard]] SizeType32 WindowBlockManager::getNumFreeBlocks() const noexcept { return mEvictionPolicy->getNumFreeBlocks(kPrimaryLevel); } std::deque BlockManager::getLatestEvents(std::optional timeout) const { return mEventManager ? mEventManager->getEvents(timeout) : std::deque{}; } std::vector BlockManager::storeBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { std::vector pinnedBlockIds; for (auto& [_, manager] : mWindowBlockManagers) { pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); } return pinnedBlockIds; } std::optional BlockManager::releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { // Released block will be stored when reuse is enabled. // Reuse is implied to be enabled if llmRequest is provided. std::optional lastStoredId = std::nullopt; // For now, the attention kernel only accepts a single // "prepopulatedPromptLen", that is, all window sizes will use the same // prepopulated prompt length, so it is meaningless right now to save // blocks only for a certain window size while blocks in the other // window size are not valid for saving for reuse. 
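// Concrete case (illustrative): with two window-size groups, e.g. a 1024-token SWA group
// and a full-attention group, a single invalid group makes the conjunction computed below
// false, so neither group stores its blocks; this keeps the single prepopulatedPromptLen
// consistent across all window sizes.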
bool isAllWindowSizesValidForStoreForReuse = true; for (auto& [windowSize, manager] : mWindowBlockManagers) { isAllWindowSizesValidForStoreForReuse &= manager.isSequenceValidForStoreForReuse(sequence.getRequestId()); } for (auto& [_, manager] : mWindowBlockManagers) { if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1 || !isAllWindowSizesValidForStoreForReuse) { lastStoredId = manager.releaseBlocks(sequence, std::nullopt); } else { lastStoredId = manager.releaseBlocks(sequence, llmRequest); } } return lastStoredId; } void BlockManager::pinBlocks(GenerationRequest& sequence) { for (auto& [_, manager] : mWindowBlockManagers) { manager.pinBlocks(sequence); } } void BlockManager::unpinBlocksById(std::vector const& blockIds) { // Use the first window size if (mWindowBlockManagers.empty()) { return; } auto& firstManager = mWindowBlockManagers.begin()->second; firstManager.unpinBlocksById(blockIds); } void WindowBlockManager::pinBlocks(GenerationRequest& sequence) { auto const requestId = sequence.getRequestId(); auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); for (auto& block : allocatedBlocks) { block->incRefCount(); } } void WindowBlockManager::unpinBlocksById(std::vector const& blockIds) { if (blockIds.empty()) { return; } for (auto const& blockId : blockIds) { TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast(blockId) < mAllBlocksById.size(), "Block id %d is out of range", blockId); auto block = mAllBlocksById[blockId]; if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId) { block->decRefCount(); if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); } } } } void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { for (auto& [_, manager] : mWindowBlockManagers) { if (manager.isSWA()) { // SWA cannot store new blocks on the fly because the block stored // may go OOW and be reused by another sequence. continue; } manager.storeNewBlock(sequence, llmRequest); } } void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { auto constexpr beamIdx = 0; auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize); if (uniqueTokens.size() == 0) { return; } // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume // the last token's state is not filled yet. 
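// Worked example (hypothetical numbers): with mTokensPerBlock = 64 and
// uniqueTokens.size() = 129, usableSize = 128, which is block-aligned, so the
// newest full block is a candidate for storing below. With uniqueTokens.size() = 130,
// usableSize = 129 and 129 % 64 != 0, so storeNewBlock returns without storing.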
auto const usableSize = static_cast(uniqueTokens.size()) - 1; if (usableSize % mTokensPerBlock != 0) { return; } auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); if (blockKeys.size() < 2 || cacheBlockIds[beamIdx].size() < blockKeys.size()) { // store all blocks TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str()); (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); return; } auto lastBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 1]); auto prevBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 2]); // If the previous block is not in the radix tree, we need to store all blocks if (prevBlock->getPrevBlock() == nullptr) { TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str()); (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); return; } if (lastBlock->getPrevBlock() != nullptr) { // If the last block is already in the radix tree, there is no need to store it again TLLM_LOG_DEBUG("%s::storeNewBlock - no need to store", mLogPrefix.c_str()); return; } TLLM_LOG_DEBUG("%s::storeNewBlock - store the last block", mLogPrefix.c_str()); (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); } std::vector WindowBlockManager::storeBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { auto constexpr beamIdx = 0; auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize); // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume // the last token's state is not filled yet. auto const usableSize = static_cast(uniqueTokens.size()) - 1; auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks); return pinnedBlockIds; } std::optional WindowBlockManager::releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest) { auto const requestId = sequence.getRequestId(); std::optional lastStoredId = std::nullopt; auto node = mAllocatedBlocksPerSeq.extract(requestId); TLLM_CHECK(node); auto& allocatedBlocks = node.mapped(); if (llmRequest.has_value()) { // If llmRequest is provided, block store for reuse is enabled. if (!isSequenceValidForStoreForReuse(requestId)) { TLLM_LOG_DEBUG( "%s::releaseBlocks - sequence %lu does not have all blocks valid, blocks are not saved for reuse", mLogPrefix.c_str(), sequence.getRequestId()); } else { if (mIsSWA) { TLLM_LOG_DEBUG("%s::releaseBlocks - sequence %lu is valid for store for reuse", mLogPrefix.c_str(), sequence.getRequestId()); } auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0); // Only (length - 1) tokens of the sequence have their kv-state // recorded in kv-cache. We assume the last token's state is not filled yet.
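// Sketch of the store-for-reuse path below (hypothetical numbers): with
// mTokensPerBlock = 64 and 200 unique tokens, usableSize = 199 and
// chopVectorIntoBlocks(..., /*allowPartial=*/true) yields chunks of
// 64 + 64 + 64 + 7 tokens; buildBlockKeys turns each chunk into a BlockKey, and
// storeBlocks() reuses blocks matching an existing prefix in the reuse tree and
// inserts the remaining ones behind it.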
auto const usableSize = static_cast(uniqueTokens.size()) - 1; auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true); auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); std::vector cacheBlockIds(allocatedBlocks.size()); std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(), [](BlockPtr const& block) { return block->getBlockId(); }); auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds); TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(), sequence.getRequestId(), numBlocksStoredForReuse); } } for (auto it = allocatedBlocks.rbegin(); it != allocatedBlocks.rend() - sequence.getNumFrontBlocksRemoved(); ++it) { auto& block = *it; // Decrease ref count if (block->hasRefs()) { // An out-of-window block may not have any ref count. block->decRefCount(); } // If ref count is zero, move block to free blocks if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); } } // Remove stored block ids in sequence sequence.clearCacheBlocks(mWindowSize); return lastStoredId; } void BlockManager::schedulingReleaseBlocks(RequestIdType requestId) { for (auto& [_, manager] : mWindowBlockManagers) { manager.schedulingReleaseBlocks(requestId); } } void WindowBlockManager::schedulingReleaseBlocks(RequestIdType requestId) { for (auto& block : mAllocatedBlocksPerSeq.at(requestId)) { // Decrease ref count block->decSchedulingRefCount(); // If ref count is zero, move block to free blocks if (!block->hasSchedulingRefs()) { mSchedulingNumFreeBlocks++; } } } KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse, nullptr, enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim) { } KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim) : KVCacheManager(numKvHeadsPerLayer, sizePerHead, 
tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim) { } KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim) : mMaxBeamWidth(maxBeamWidth) , mDataType(dtype) , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end())) , mTokensPerBlock(tokensPerBlock) , mSinkBubbleLength(BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock)) , mSinkBlockTokenLength(mSinkBubbleLength + sinkTokenLength) , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager), std::nullopt, enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse} { TLLM_CHECK_WITH_INFO(mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0, "[kv cache manager] streamLLM is not supported at the moment"); TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow) != maxAttentionWindowVec.end()); // The sink tokens are stored in blocks separate from other tokens. // If the last block of sink tokens is only partially filled, // we fill that block with a "bubble" to reach the number of tokens per block. TLLM_CHECK(mSinkBlockTokenLength % tokensPerBlock == 0); TLLM_LOG_DEBUG("KV cache block reuse is %s", mEnableBlockReuse ? 
"enabled" : "disabled"); mSequences.reserve(maxNumSequences); } KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager), enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim) { } void KVCacheManager::allocatePools(bool useUvm) { mBlockManager.allocatePools(useUvm); auto const numPools = mBlockManager.getNumPools(); uint64_t cacheSizeBytes = 0; for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++) { auto const cacheShape = mBlockManager.getPrimaryPool(poolIdx)->getShape(); auto const cacheVolume = ITensor::volume(cacheShape); #ifdef ENABLE_FP4 auto const isFp4 = mDataType == nvinfer1::DataType::kFP4; #else auto const isFp4 = false; #endif if (!isFp4) { cacheSizeBytes += cacheVolume * BufferDataType(mDataType).getSize(); } else { cacheSizeBytes += (cacheVolume * 4) / 8; } } // Save the total number of bytes allocated for the KV-cache for KvCacheStats mAllocatedBytes = cacheSizeBytes; if (tc::Logger::getLogger()->getLevel() <= tc::Logger::INFO) { TLLM_LOG_INFO("Number of tokens per block: %d.", mBlockManager.getTokensPerBlock()); auto const maxNumTokens = mBlockManager.getNumPrimaryBlocks() * mBlockManager.getTokensPerBlock(); TLLM_LOG_INFO("[MemUsageChange] Allocated %0.2f GiB for max tokens in paged KV cache (%d).", cacheSizeBytes / static_cast(1 << 30), maxNumTokens); } auto const numKVPools = mBlockManager.getNumPools(/*include_block_scalar_pools=*/false, /*include_indexer_k_cache_pools=*/false); auto const numBlockScalePools = mBlockManager.getNumPools(/*includeBlockScalePools=*/true, /*includeIndexerKCachePools=*/false) - numKVPools; // Code in the attention kernels is cleaner if we can access the KV values and block scales separately. mBlockPoolPointers = BufferManager::cpu(ITensor::makeShape({numKVPools, 2}), TRTDataType::value); mBlockScalePoolPointers = BufferManager::cpu(ITensor::makeShape({numBlockScalePools, 2}), TRTDataType::value); auto poolPtrsRange = BufferRange(*mBlockPoolPointers); auto blockScalePtrsRange = BufferRange(*mBlockScalePoolPointers); SizeType32 kvPoolIdx = 0; SizeType32 blockScalePoolIdx = 0; for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++) { auto const& pool = mBlockManager.getPool(poolIdx); auto& outIdx = pool.containsBlockScales ? blockScalePoolIdx : kvPoolIdx; auto& outRange = pool.containsBlockScales ? 
blockScalePtrsRange : poolPtrsRange; if (pool.containsIndexerKCache) { mIndexerKCachePoolPointers = pool.primaryPtr; } else { outRange[outIdx * 2] = pool.primaryPtr->data(); outRange[outIdx * 2 + 1] = pool.secondaryPtr ? pool.secondaryPtr->data() : nullptr; outIdx++; } } auto const numLayers = mBlockManager.getNumLayers(); mLayerToPoolMapping = BufferManager::cpu(ITensor::makeShape({numLayers, 2}), TRTDataType::value); auto poolMappingRange = BufferRange(*mLayerToPoolMapping); for (SizeType32 layerIdx = 0; layerIdx < numLayers; layerIdx++) { auto const indexOfPool = mBlockManager.getLayerPoolIdx(layerIdx); auto const layerIdxInCachePool = mBlockManager.getPoolLayerIdx(layerIdx); poolMappingRange[layerIdx * 2] = indexOfPool; poolMappingRange[layerIdx * 2 + 1] = layerIdxInCachePool; } } void KVCacheManager::releasePools() { mBlockManager.releasePools(); } void KVCacheManager::startScheduling() { mBlockManager.startScheduling(); } SizeType32 KVCacheManager::getNeededBlocksOneStep( LlmRequest const& req, bool twoStepsLookAhead, SizeType32 windowSize) const { if ((req.isContextInitState() && req.isFirstContextChunk()) || req.isDisaggGenerationInitState()) { auto const chunkSize = req.mMaxNewTokens; auto const maxDraftTokensToAdd = req.getNumDraftTokens(); auto const promptCacheLen = std::min((isCrossKv() ? req.getEncoderOutputLen() : req.mPromptLen) + maxDraftTokensToAdd, windowSize + chunkSize) + mSinkBubbleLength; auto const numSharedBlocks = promptCacheLen / getTokensPerBlock(); auto const numUnSharedTokens = promptCacheLen % getTokensPerBlock(); auto const numUnSharedBlocks = tc::ceilDiv(numUnSharedTokens, getTokensPerBlock()) * req.mSamplingConfig.beamWidth; auto const numRequiredBlocks = numSharedBlocks + numUnSharedBlocks; return numRequiredBlocks; } if (req.isGenerationInProgressState()) { if (isCrossKv()) { return 0; } auto const numCurrTokens = getSequence(req.mRequestId).getNumTokens(); auto const generatedTokens = numCurrTokens - req.getPromptLen(); auto const maxTokensToAddToKVCache = req.mMaxNewTokens - generatedTokens; auto const tokensPerStep = req.getNumDraftTokens() + 1; auto const maxTokensToAdd = std::min((twoStepsLookAhead ? 
2 : 1) * tokensPerStep, maxTokensToAddToKVCache); auto const numNextTokens = numCurrTokens + maxTokensToAdd; if (numNextTokens > mBlockManager.getWindowSizeMetadata(windowSize).maxTokenNum) { return 0; } auto const numCurrBlocks = tc::ceilDiv(numCurrTokens, getTokensPerBlock()); auto const numNextBlocks = tc::ceilDiv(numNextTokens, getTokensPerBlock()); auto const numRequiredBlocks = (numNextBlocks - numCurrBlocks) * req.mSamplingConfig.beamWidth; return numRequiredBlocks; } return 0; } SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const { if (isCrossKv()) { if (req.isContextInitState() && req.getContextCurrentPosition() == 0) { return tc::ceilDiv(req.getEncoderOutputLen(), getTokensPerBlock()); } return 0; // cross KV cache doesn't grow after the initial context phase } auto const temporaryAttentionWindow = mBlockManager.getWindowSizeMetadata(windowSize).temporaryAttentionWindow; SizeType32 const numContextBlocks = (std::min(req.mPromptLen, windowSize + temporaryAttentionWindow) + mSinkBubbleLength) / getTokensPerBlock(); SizeType32 const numTotalBlocksPerBeam = tc::ceilDiv( std::min(req.mPromptLen + req.mMaxNewTokens, windowSize + temporaryAttentionWindow) + mSinkBubbleLength, getTokensPerBlock()); SizeType32 const numGenBlocksPerBeam = numTotalBlocksPerBeam - numContextBlocks; SizeType32 numAllocBlocksPerBeam = 0; { std::scoped_lock lck(mSequencesMtx); auto const seqIt = mSequences.find(req.mRequestId); if (seqIt != mSequences.end()) { auto const& seq = seqIt->second; numAllocBlocksPerBeam = seq.getCacheBlockIds(windowSize).at(0).size(); } } // In case of sliding window attention, a new block is allocated when the // window slides (and then the out-of-window block is detached). So we // need an extra block for generation if the diff between the max sequence // length and the current sequence length crosses both a block boundary // and a window boundary. auto const isSlidingWindow = (req.mPromptLen + req.mMaxNewTokens) > windowSize; SizeType32 const currentSeqlenInBlocks = tc::ceilDiv(req.getNumTokens(0), getTokensPerBlock()); SizeType32 const maxSeqlenInBlocks = tc::ceilDiv(req.mPromptLen + req.mMaxNewTokens, getTokensPerBlock()); auto const willCrossBlockBoundary = maxSeqlenInBlocks > currentSeqlenInBlocks; auto const willCrossWindowBlockBoundary = maxSeqlenInBlocks > numTotalBlocksPerBeam; SizeType32 numExtraBlocksPerBeam = isSlidingWindow && willCrossBlockBoundary && willCrossWindowBlockBoundary ? 
1 : 0; if (numAllocBlocksPerBeam < numContextBlocks) // Still haven't allocated all context blocks { return numContextBlocks - numAllocBlocksPerBeam + (numGenBlocksPerBeam + numExtraBlocksPerBeam) * req.mSamplingConfig.beamWidth; } return (numTotalBlocksPerBeam - numAllocBlocksPerBeam + numExtraBlocksPerBeam) * req.mSamplingConfig.beamWidth; } void BlockManager::updateSequenceCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize) { auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize); auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize); auto const beamWidth = sequence.getBeamWidth(); auto* offsetsPtr = bufferCast(cacheBlocksTensor); auto const& offsetsShape = cacheBlocksTensor.getShape(); for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto const& beamCacheBlock = cacheBlocks[beamIdx]; for (SizeType32 blockIdx = 0; blockIdx < static_cast(beamCacheBlock.size()); ++blockIdx) { auto const blockId = beamCacheBlock.at(blockIdx); mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } } } void WindowBlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence) { auto const& cacheBlocks = sequence.getCacheBlockIds(mWindowSize); auto& cacheBlocksTensor = sequence.getCacheBlockIndices(mWindowSize); auto const beamWidth = sequence.getBeamWidth(); auto* offsetsPtr = bufferCast(cacheBlocksTensor); auto const& offsetsShape = cacheBlocksTensor.getShape(); for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto const& beamCacheBlock = cacheBlocks[beamIdx]; auto const blockId = beamCacheBlock.back(); auto const blockIdx = static_cast(beamCacheBlock.size() - 1); setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } } void BlockManager::updateCacheBlockOffsetsAtIdx(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) { auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize); auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize); auto const beamWidth = sequence.getBeamWidth(); auto* offsetsPtr = bufferCast(cacheBlocksTensor); auto const& offsetsShape = cacheBlocksTensor.getShape(); for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto const& beamCacheBlock = cacheBlocks[beamIdx]; auto const blockId = beamCacheBlock.at(blockIdx); mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } } void KVCacheManager::addToken(RequestIdType requestId) { // TODO: add streamLLM support auto& sequence = getSequence(requestId); sequence.addNewTokens(1); mBlockManager.adjustBlocksIfNeeded(sequence); } void WindowBlockManager::detachFrontBlock(GenerationRequest& sequence) { // streamLLM is not supported at the moment. The out of window block will // always be the 0th block. 
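// Illustrative timeline (hypothetical numbers; the exact trigger depends on
// adjustBlocksIfNeeded): with windowSize = 128 and tokensPerBlock = 64, once the
// sequence has grown far enough that the oldest remaining block (index
// getNumFrontBlocksRemoved()) lies entirely outside the attention window, it is
// detached here and, if no other sequence still holds a reference, handed back
// to the eviction policy as a free block.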
TLLM_CHECK_WITH_INFO( sequence.getBeamWidth() == 1, "[kv cache manager] detachBlock does not support beamWidth > 1 now."); auto const requestId = sequence.getRequestId(); auto const beamWidth = sequence.getBeamWidth(); auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); SizeType32 outOfWindowBlockIdx = sequence.getNumFrontBlocksRemoved(); for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto outOfWindowBlock = allocatedBlocks.at(outOfWindowBlockIdx * beamWidth + beamIdx); TLLM_LOG_DEBUG("%s::detachFrontBlock - Detaching block %d from sequence %d", mLogPrefix.c_str(), outOfWindowBlock->getBlockId(), requestId); outOfWindowBlock->decRefCount(); if (outOfWindowBlock->hasRefs()) { TLLM_LOG_DEBUG("%s::detachFrontBlock - OOW Block %d still has a non-zero ref count", mLogPrefix.c_str(), outOfWindowBlock->getBlockId()); } if (!outOfWindowBlock->hasRefs()) { mEvictionPolicy->releaseBlock(outOfWindowBlock); } } // Disconnect first block from sequence and remove it from allocated blocks sequence.removeFrontBlock(mWindowSize); } std::optional KVCacheManager::findNewContextBlock( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { auto newContextBlockOpt = mBlockManager.findNewContextBlock(uniqueTokens, llmRequest); return newContextBlockOpt; } void KVCacheManager::addSequence( RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest) { // TODO: add streamLLM support auto kvCacheRetentionConfig = llmRequest ? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig()) : executor::KvCacheRetentionConfig(); auto const [seqIt, emplaceDone] = [&] { auto lck = std::scoped_lock(mSequencesMtx); return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth, mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig); }(); TLLM_CHECK(emplaceDone); auto& sequence = seqIt->second; // Get statistics for block allocations/reuse pre request. SizeType32 const numAllocTotalBlocksPreRequest = mBlockManager.getNumAllocTotalBlocks(); SizeType32 const numAllocNewBlocksPreRequest = mBlockManager.getNumAllocNewBlocks(); SizeType32 const numReusedBlocksPreRequest = mBlockManager.getNumReusedBlocks(); SizeType32 const numMissedBlocksPreRequest = mBlockManager.getNumMissedBlocks(); if (!mBlockManager.isSequenceHeld(requestId)) { mBlockManager.holdSequence(requestId); TLLM_LOG_DEBUG( "[kv cache manager] Encounter new sequence %d, initialize sequence storage validity for all window sizes", requestId); } else { TLLM_LOG_DEBUG( "[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization", requestId); } for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { // NOTE: Caller to KVCacheManager::addSequence should deal with the chunking auto const maxTokenNum = metadata.maxTokenNum; auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow; // Consider the temporaryAttentionWindow when allocating blocks. auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow); auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); if (mEnableBlockReuse) { mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); } else { if (!mEnableBlockReuse && llmRequest && llmRequest->getKvCacheRetentionConfig().has_value()) { TLLM_LOG_WARNING( "Request %d has a retention configuration set, but block reuse is disabled. 
The retention " "config " "will " "have no effect.", llmRequest->mRequestId); } bool isShareLastContextBlock = isCrossKv() || effectiveInputLength % getTokensPerBlock() == 0; mBlockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock); } mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize); } if (llmRequest) { // Update statistics for block allocations/reuse per request. llmRequest->updateAllocTotalBlocksPerRequest( mBlockManager.getNumAllocTotalBlocks() - numAllocTotalBlocksPreRequest); llmRequest->updateAllocNewBlocksPerRequest(mBlockManager.getNumAllocNewBlocks() - numAllocNewBlocksPreRequest); llmRequest->updateReusedBlocksPerRequest(mBlockManager.getNumReusedBlocks() - numReusedBlocksPreRequest); llmRequest->updateMissedBlocksPerRequest(mBlockManager.getNumMissedBlocks() - numMissedBlocksPreRequest); } } void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { auto const requestId = llmRequest.mRequestId; bool found = false; { // protect the mSequences std::scoped_lock lock(mSequencesMtx); found = mSequences.find(requestId) != mSequences.end(); } if (found) { auto& sequence = getSequence(requestId); if (mEnableBlockReuse && !llmRequest.isDummyRequest()) { mBlockManager.storeContextBlocks(sequence, llmRequest); } } else { TLLM_LOG_WARNING("[kv cache manager] storeContextBlocks: Can not find sequence for request %lu", requestId); } } void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) { // We store newest block for potential reuse only if: // - Beam search is NOT enabled // - Block reuse is enabled. auto const requestId = llmRequest.mRequestId; auto& sequence = getSequence(requestId); if (sequence.getBeamWidth() > 1 || !mEnableBlockReuse) { return; } mBlockManager.storeNewBlock(sequence, llmRequest); } std::optional KVCacheManager::removeSequence( RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks) { TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); auto sequenceNode = [this, requestId] { std::scoped_lock lock(mSequencesMtx); return mSequences.extract(requestId); }(); std::optional lastStoredId = std::nullopt; if (!sequenceNode.empty()) { if (mEnableBlockReuse) { lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), llmRequest, pinBlocks); } else { lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), std::nullopt, pinBlocks); } } if (mBlockManager.isSequenceHeld(requestId)) { mBlockManager.releaseSequence(requestId); TLLM_LOG_DEBUG("Remove sequence %d, release sequence storage validity for all window sizes", requestId); } TLLM_CHECK(!mBlockManager.isSequenceHeld(requestId)); TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); return lastStoredId; } std::vector KVCacheManager::storeBlocksForReuse( RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks) { TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); auto& sequence = getSequence(requestId); auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? 
"CROSS" : "SELF", __PRETTY_FUNCTION__); return pinnedBlockIds; } void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId) { // Mimic Free all blocks for this sequence mBlockManager.schedulingReleaseBlocks(requestId); } void KVCacheManager::pinBlocks(RequestIdType requestId) { auto& sequence = getSequence(requestId); mBlockManager.pinBlocks(sequence); } void KVCacheManager::unpinBlocksById(std::vector const& blockIds) { mBlockManager.unpinBlocksById(blockIds); } SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const { auto const& sequence = getSequence(requestId); auto const beamWidth = sequence.getBeamWidth(); auto* dstPtr = bufferCast(output); auto const& dstShape = output.getShape(); SizeType32 constexpr kIdx = 0; SizeType32 constexpr vIdx = 1; SizeType32 maxBlockCount{0}; // Get page table for each KV cache pool SizeType32 absolutePoolIdx = 0; for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { auto const& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize); auto const* srcPtr = bufferCast(cacheBlocksTensor); auto const& srcShape = cacheBlocksTensor.getShape(); auto const& cacheBlockIds = sequence.getCacheBlockIds(windowSize); for (SizeType32 poolIdx = 0; poolIdx < metadata.numPools; poolIdx++, absolutePoolIdx++) { for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { auto const beamBlockCount = cacheBlockIds[beamIdx].size(); auto const copyChunkSize = beamBlockCount * sizeof(tk::KVCacheIndex); for (auto xIdx : {kIdx, vIdx}) { auto const srcIndex = tc::flat_index(srcShape.d, poolIdx, beamIdx, xIdx, 0); auto const dstIndex = tc::flat_index(dstShape.d, absolutePoolIdx, outputSlotOffset + beamIdx, xIdx, 0); std::memcpy(dstPtr + dstIndex, srcPtr + srcIndex, copyChunkSize); } maxBlockCount = std::max(maxBlockCount, static_cast(beamBlockCount)); } } } return maxBlockCount; } void KVCacheManager::getBlockOffsetsOfBatch( ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const { // Get page table for each KV cache pool for (auto batchSlotIdx = 0; batchSlotIdx < batchSize; ++batchSlotIdx) { copyBlockOffsets(output, batchSlotIdx * beamWidth, firstBatchSlotIdx + batchSlotIdx); } } std::map> BaseKVCacheManager::groupLayersByWindowSize( std::vector const& maxAttentionWindowVec, SizeType32 numLayers) { auto const numNonUniqueWindowSizes = static_cast(maxAttentionWindowVec.size()); std::map> uniqueWindowSizeToLayers; for (SizeType32 layerIdx = 0; layerIdx < numLayers; layerIdx++) { /* At this point (Deep in the construction of TrtGptModel), maxAttentionWindowVec isn't "stretched" to the length of numLayers yet. So, we need to rotate the window sizes per layer with modulo. 
*/ auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes); uniqueWindowSizeToLayers[windowSize].push_back(layerIdx); } return uniqueWindowSizeToLayers; } std::tuple BaseKVCacheManager::calculateFreeMemBytes( runtime::BufferManager const& bufferManager, executor::KvCacheConfig const& config) { auto const freeMemFraction = config.getFreeGpuMemoryFraction().value_or(executor::KvCacheConfig::kDefaultGpuMemFraction); TLLM_CHECK_WITH_INFO(freeMemFraction < 1.0F, "Invalid freeMemFraction, freeMemFraction (%f) must be smaller than 1.0f", freeMemFraction); if (config.getMaxTokens().has_value()) { if (config.getFreeGpuMemoryFraction().has_value()) { TLLM_LOG_WARNING( "Both freeGpuMemoryFraction (aka kv_cache_free_gpu_mem_fraction) " "and maxTokens (aka max_tokens_in_paged_kv_cache) " "are set (to %f and %ld, respectively). The smaller value will be used.", freeMemFraction, (int64_t) config.getMaxTokens().value()); } } TLLM_CUDA_CHECK(::cudaDeviceSynchronize()); auto const [freeMem, totalMem] = tc::getDeviceMemoryInfo(config.getUseUvm()); auto const finalFreeMem = freeMem + bufferManager.memoryPoolFree(); TLLM_LOG_INFO("Memory usage when calculating max tokens in paged kv cache: total: %0.2f GiB, available: %0.2f GiB", totalMem / static_cast(1 << 30), finalFreeMem / static_cast(1 << 30)); TLLM_CHECK_WITH_INFO(finalFreeMem <= totalMem, "Free memory cannot exceed total memory"); auto const freePrimaryMemBytes = static_cast(finalFreeMem * freeMemFraction); auto const freeSecondaryMemBytes = config.getHostCacheSize().value_or(0); TLLM_LOG_DEBUG("Calculated free memory: {.freePrimaryMemBytes=%" PRIu64 ", .freeSecondaryMemBytes=%" PRIu64 "}", freePrimaryMemBytes, freeSecondaryMemBytes); return std::make_tuple(freePrimaryMemBytes, freeSecondaryMemBytes); } namespace { bool isSortedVectorIdenticalAcrossAllRanks(WorldConfig const& worldConfig, std::vector const& vector) { auto const numRanks = worldConfig.getSize(); auto const numElements = static_cast(vector.size()); int maxNumElements = 0; int minNumElements = 0; COMM_SESSION.allreduce(&numElements, &maxNumElements, 1, mpi::MpiType::kINT32, mpi::MpiOp::MAX); COMM_SESSION.allreduce(&numElements, &minNumElements, 1, mpi::MpiType::kINT32, mpi::MpiOp::MIN); if (maxNumElements != minNumElements) { return false; } std::vector allElements(numElements * numRanks); COMM_SESSION.allgather(vector.data(), allElements.data(), numElements, mpi::MpiType::kUINT32); for (int i = 0; i < numElements; ++i) { auto const ref = allElements.at(i); for (int rank = 1; rank < numRanks; ++rank) { if (allElements[rank * numElements + i] != ref) return false; } } return true; } } // namespace BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfig const& config, bool isCrossAttention, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, std::map> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes, uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor) { TLLM_LOG_DEBUG("Calculating max num blocks for %s: {.allottedPrimaryMemBytes=%" PRIu64 ", .allottedSecondaryMemBytes=%" PRIu64 "}", isCrossAttention ? 
"Cross KvCacheManager" : "Self KvCacheManager", allottedPrimaryMemBytes, allottedSecondaryMemBytes); if (config.getMaxTokens().has_value() && windowSizeToLayers.size() > 1) { TLLM_LOG_WARNING( "Setting maxTokens when using Variable Sliding Window Attention is a strange concept, as it limits " "the number of max tokens *per window size* [limiting the sum of all window sizes is even stranger]. " "Anticipating the effects of this requires quite a complex calculation, and it probably isn't the " "configuration you meant to use."); } std::map cacheSizeBytesPerTokenPerWindow; for (auto const& [windowSize, managedLayers] : windowSizeToLayers) { auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize( modelConfig, managedLayers, isCrossAttention, kvFactor); auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(dtype).getSize(); cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken; } bool const isVSWA = cacheSizeBytesPerTokenPerWindow.size() > 1; TLLM_LOG_DEBUG("extraCostMemory [Gib]: %0.2f", extraCostMemory / static_cast(1 << 30)); allottedPrimaryMemBytes = allottedPrimaryMemBytes - extraCostMemory; auto const tokensPerBlock = modelConfig.getTokensPerBlock(); auto const calculatePrimaryBlocks = [&](SizeType32 windowSize, float windowSizeShare, SizeType32 cacheSizeBytesPerToken) { TLLM_LOG_DEBUG("windowSizeShare: %f, cacheSizeBytesPerToken: %d", windowSizeShare, cacheSizeBytesPerToken); auto maxTokens = static_cast( allottedPrimaryMemBytes * windowSizeShare / static_cast(cacheSizeBytesPerToken)); // kv_cache_config.max_tokens is not effective in VSWA scheme if (config.getMaxTokens().has_value() && !isVSWA) { auto const maxTokensFromConfig = static_cast(config.getMaxTokens().value()); if (maxTokensFromConfig < maxTokens) { TLLM_LOG_DEBUG("Maximum kv-cache token overridden by configuration as '%ld'.", maxTokensFromConfig); maxTokens = std::min(maxTokensFromConfig, maxTokens); } } TLLM_LOG_DEBUG("Primary maxTokens for windowSize %d: %ld", windowSize, maxTokens); SizeType32 const blocksInPrimaryPool = tc::ceilDiv(maxTokens, tokensPerBlock); TLLM_LOG_DEBUG( "Number of blocks in KV cache primary pool for windowSize %d: %d", windowSize, blocksInPrimaryPool); return blocksInPrimaryPool; }; auto const calculateSecondaryBlocks = [&](SizeType32 windowSize, float windowSizeShare, SizeType32 cacheSizeBytesPerToken) { auto const maxTokensSecondary = static_cast(allottedSecondaryMemBytes * windowSizeShare / cacheSizeBytesPerToken); SizeType32 const blocksInSecondaryPool = std::max(0, maxTokensSecondary / tokensPerBlock); TLLM_LOG_DEBUG( "Number of blocks in KV cache secondary pool for windowSize %d: %d, onboard blocks to primary memory " "before reuse: %s", windowSize, blocksInSecondaryPool, config.getOnboardBlocks() ? "true" : "false"); return blocksInSecondaryPool; }; std::map windowSizeToShare; // By default, we allocate equal proportion shares of memory for all // window sizes (see the else case). With TRTLLM_WINDOW_SIZE_SHARES, // we can override this behavior to adjust the memory share of each // window size. For example, if we have window size of [512, 32768], // then setting TRTLLM_WINDOW_SIZE_SHARES=0.4,0.6 will be allocating // 40% of the memory to window size 512 and 60% of the memory to window // size 32768. 
if (auto envStr = std::getenv("TRTLLM_WINDOW_SIZE_SHARES")) { std::stringstream ss(envStr); std::vector shares; float share; while (ss >> share) { shares.push_back(share); if (ss.peek() == ',') ss.ignore(); } TLLM_CHECK_WITH_INFO(shares.size() == windowSizeToLayers.size(), "Number of shares in TRTLLM_WINDOW_SIZE_SHARES (%ld) must match number of window sizes (%ld)", shares.size(), windowSizeToLayers.size()); float sumShares = 0.0f; for (auto s : shares) { TLLM_CHECK_WITH_INFO(0.0f <= s && s <= 1.0f, "Shares must be in value range [0,1], got %f", s); sumShares += s; } TLLM_CHECK_WITH_INFO(sumShares > 0.0f, "Sum of shares must be > 0."); // Normalize shares to 1.0 for (auto& s : shares) { s /= sumShares; } size_t i = 0; for (auto const& [windowSize, _] : windowSizeToLayers) { windowSizeToShare[windowSize] = shares[i++]; } } else { // NOTE: Righteously, blocks allocated should be proportional with // regard to window size. Currently, we are first allocating identical // number of blocks for all layers to achieve identical performance. for (auto const& [windowSize, _] : windowSizeToLayers) { windowSizeToShare[windowSize] = 1.0f / windowSizeToLayers.size(); } } std::vector blocksPrimary; std::vector blocksSecondary; for (auto const& [windowSize, managedLayers] : windowSizeToLayers) { auto const cacheSizeBytesPerToken = cacheSizeBytesPerTokenPerWindow.at(windowSize); auto const windowSizeShare = windowSizeToShare.at(windowSize); auto const blocksInPrimaryPool = calculatePrimaryBlocks(windowSize, windowSizeShare, cacheSizeBytesPerToken); auto const blocksInSecondaryPool = calculateSecondaryBlocks(windowSize, windowSizeShare, cacheSizeBytesPerToken); blocksPrimary.push_back(blocksInPrimaryPool); blocksSecondary.push_back(blocksInSecondaryPool); } std::vector windowSizes; windowSizes.reserve(windowSizeToLayers.size()); for (auto const& [k, _] : windowSizeToLayers) { windowSizes.push_back(k); } if (worldConfig.getSize() > 1) { TLLM_CHECK(worldConfig.validMpiConfig()); auto const rank = worldConfig.getRank(); using tensorrt_llm::common::vec2str; TLLM_CHECK_WITH_INFO(isSortedVectorIdenticalAcrossAllRanks( worldConfig, windowSizes), // sorted thanks to windowSizeToLayers being a std::map "[RANK %d] Asymmetrical pipeline parallelism detected: Ranks either have a different number of window " "sizes, or differing values. This is not supported with Variable Sliding Window Attention. 
Local window " "sizes for reference: %s", rank, vec2str(windowSizes).c_str()); TLLM_LOG_DEBUG( "[RANK %d] Before mpi::MpiOp::MIN reduction: window sizes %s / primary blocks %s / secondary blocks %s", rank, vec2str(windowSizes).c_str(), vec2str(blocksPrimary).c_str(), vec2str(blocksSecondary).c_str()); // make sure all ranks use same value for max blocks auto blocksWorld = blocksPrimary; COMM_SESSION.allreduce( blocksPrimary.data(), blocksWorld.data(), windowSizes.size(), mpi::MpiType::kINT32, mpi::MpiOp::MIN); blocksPrimary = blocksWorld; COMM_SESSION.allreduce( blocksSecondary.data(), blocksWorld.data(), windowSizes.size(), mpi::MpiType::kINT32, mpi::MpiOp::MIN); blocksSecondary = blocksWorld; TLLM_LOG_DEBUG( "[RANK %d] After mpi::MpiOp::MIN reduction: window sizes %s / primary blocks %s / secondary blocks %s", rank, vec2str(windowSizes).c_str(), vec2str(blocksPrimary).c_str(), vec2str(blocksSecondary).c_str()); } BlocksPerWindow windowSizeToBlocks; TLLM_LOG_INFO("Blocks per window size:"); for (size_t i = 0; i < windowSizes.size(); ++i) { auto const windowSize = windowSizes.at(i); auto const primaryBlocks = blocksPrimary.at(i); auto const secondayBlocks = blocksSecondary.at(i); TLLM_LOG_INFO( "[windowSize=%d] {.primaryBlocks=%d, .secondayBlocks=%d}", windowSize, primaryBlocks, secondayBlocks); windowSizeToBlocks[windowSize] = {primaryBlocks, secondayBlocks}; } return windowSizeToBlocks; } void KVCacheManager::removeToken(RequestIdType requestId) { // TODO: add streamLLM support auto& sequence = getSequence(requestId); if (sequence.getNumTokens() == 0) { return; } TLLM_CHECK_WITH_INFO(sequence.getBeamWidth() == 1, "[kv cache manager] removeToken does not support beamWidth > 1"); sequence.removeTokens(1); for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { SizeType32 const tokensInWindow = sequence.getNumTokens() % windowSize; if (tokensInWindow % getTokensPerBlock() == 0) { mBlockManager.releaseLastBlock(sequence, windowSize); } } } void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths) { for (SizeType32 si = 0; si < rewindLengths; ++si) { removeToken(requestId); } } GenerationRequest const& KVCacheManager::getSequence(RequestIdType requestId) const { auto lck = std::scoped_lock(mSequencesMtx); return mSequences.at(requestId); } GenerationRequest& KVCacheManager::getSequence(RequestIdType requestId) { auto lck = std::scoped_lock(mSequencesMtx); return mSequences.at(requestId); } SizeType32 BaseKVCacheManager::getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock) { auto const sinkTokensInLastBlock = sinkTokenLen % tokensPerBlock; auto const sinkBubbleLength = sinkTokensInLastBlock == 0 ? 
0 : tokensPerBlock - sinkTokensInLastBlock; return sinkBubbleLength; } std::vector> const& KVCacheManager::getCacheBlockIds( RequestIdType requestId, SizeType32 windowSize) const { return getSequence(requestId).getCacheBlockIds(windowSize); } std::vector>> KVCacheManager::getBatchCacheBlockIds( std::vector const& requestIds, SizeType32 windowSize) const { std::vector>> result{}; result.reserve(requestIds.size()); for (auto const& requestId : requestIds) { auto const& sequence = getSequence(requestId); result.emplace_back(sequence.getCacheBlockIds(windowSize)); } return result; } std::optional KVCacheManager::getLastBlockId(LlmRequest::RequestIdType requestId) const { auto const& seq = getSequence(requestId); // Use the first window size auto firstWindowSize = mBlockManager.getFirstWindowSize(); if (firstWindowSize == 0) { return std::nullopt; } auto const& perBeam = seq.getCacheBlockIds(firstWindowSize); if (perBeam.empty() || perBeam[0].empty()) { return std::nullopt; } return perBeam[0].back(); } runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const { TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1, "getUniquePrimaryPool is only supported for a single window size"); return mBlockManager.getPrimaryPool(0); } runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const { return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx)); } runtime::ITensor::SharedPtr KVCacheManager::getIndexerKCachePool() const { return mIndexerKCachePoolPointers; } SizeType32 KVCacheManager::getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const { auto minMaxBatchSizeAllWindows = std::numeric_limits::max(); for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { auto const blockRequirementsPerSequence = KVCacheManager::calculateMaxBlockRequirements( inputLength, outputLength, mSinkBlockTokenLength, windowSize, mMaxBeamWidth, mTokensPerBlock); auto const maxBatchSizeWindow = metadata.allottedPrimaryBlocks / blockRequirementsPerSequence; // The window with the *smallest* max batch size is the limiting factor // Hence, the std::*min* of all the max batch sizes is chosen minMaxBatchSizeAllWindows = std::min(minMaxBatchSizeAllWindows, maxBatchSizeWindow); } return minMaxBatchSizeAllWindows; } SizeType32 KVCacheManager::calculateMaxBlockRequirementsPerBeam( SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock) { auto const sinkBubbleLength = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock); auto const actualSeqLen = std::min(sequenceLength, windowSize); auto actualMaxTokenNum = actualSeqLen + sinkBubbleLength; auto numBlocks = tc::ceilDiv(actualMaxTokenNum, tokensPerBlock); if (sequenceLength > windowSize) { numBlocks += kSWAExtraBlock; } return numBlocks; } SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock) { // We split beam width > 1, as it introduces a lot of complexity. 
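// Worked example for the beamWidth == 1 path below (hypothetical numbers):
// inputLength = 100, outputLength = 28, sinkTokenLength = 0, tokensPerBlock = 64
// and windowSize = 4096 give wholeSequenceLength = 128, no sink bubble, and
// ceilDiv(128, 64) = 2 blocks; since 128 <= windowSize, no kSWAExtraBlock is added.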
auto const wholeSequenceLength = inputLength + outputLength; if (beamWidth == 1) { return KVCacheManager::calculateMaxBlockRequirementsPerBeam( wholeSequenceLength, sinkTokenLength, windowSize, tokensPerBlock); } if (windowSize <= outputLength) { // We at most will need outputLength of distinct blocks for SWA return KVCacheManager::calculateMaxBlockRequirementsPerBeam( outputLength, sinkTokenLength, windowSize, tokensPerBlock) * beamWidth; } // Otherwise, we calculate how many tokens will be in output blocks. auto const effectiveAttentionWindow = std::min(windowSize, wholeSequenceLength); auto const numContextTokensInAttentionWindow = effectiveAttentionWindow - outputLength; // This is positive because we handled the other case above. auto const sinkBubbleLength = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock); auto const numContextBlocks = (numContextTokensInAttentionWindow + sinkBubbleLength) / tokensPerBlock; auto const leftoverContextToken = numContextTokensInAttentionWindow - numContextBlocks * tokensPerBlock; auto numOutputBlocks = tc::ceilDiv(outputLength + leftoverContextToken, tokensPerBlock); if (wholeSequenceLength > windowSize) { numOutputBlocks += kSWAExtraBlock; } return numContextBlocks + numOutputBlocks * beamWidth; } [[nodiscard]] SizeType32 KVCacheManager::calculateMaxAttentionWindow(SizeType32 inputLength, SizeType32 outputLength, SizeType32 sinkTokenLength, SizeType32 blockCapacity, SizeType32 beamWidth, SizeType32 tokensPerBlock) { // The function that gives the number of blocks required given an attention window is only linear by part. It is // however, monotonically increasing in the attention window. Therefore, we need to find in which part of the // function we are. First, are we in a case where not even the entire output will fit? auto const outputBlockRequirements = calculateMaxBlockRequirements(0, outputLength, sinkTokenLength, outputLength, beamWidth, tokensPerBlock); if (outputBlockRequirements > blockCapacity) { return (blockCapacity / beamWidth) * tokensPerBlock; } // Otherwise, we need to determine how many context tokens we can fit on top of the output tokens. First, there // are a few context tokens we might be able to fit 'for free' because the output is not a multiple of the // number of tokens per block. auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements; return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength); } } // namespace tensorrt_llm::batch_manager::kv_cache_manager
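// Usage sketch for the capacity helpers above (hypothetical numbers, not part of
// the build): with blockCapacity = 8, tokensPerBlock = 64, beamWidth = 1,
// sinkTokenLength = 0, inputLength = 1000 and outputLength = 100, the output-only
// requirement calculateMaxBlockRequirements(0, 100, 0, /*windowSize=*/100, 1, 64)
// is 2 blocks, leaving 6 blocks of headroom, so calculateMaxAttentionWindow
// returns std::min(100 + 6 * 64, 1000 + 100) = 484.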