/* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "tensorrt_llm/batch_manager/kvCacheManager.h" namespace tensorrt_llm::batch_manager::kv_cache_manager { class BlockIterator; class BlockRange { public: // C++20 std::default_sentinel_t equivalent struct Sentinel { }; static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId, SizeType32 beam = kFIRST_AND_ONLY_BEAM) { assert(kFIRST_AND_ONLY_BEAM == beam); auto const windowSize = firstWindowSize(cacheManager); auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM); return BlockRange(cacheManager, blockIds, requestId); } static BlockRange fromNewlyAllocatedBlockIds( BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) { auto const windowSize = firstWindowSize(cacheManager); auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize); return BlockRange(cacheManager, blockIds, requestId); } BlockRange(runtime::ITensor::SharedPtr pool, std::vector const& blockIds) // Only used in tests : mManager{nullptr} , mPool{std::move(pool)} , mWindowSize{0} , mRequestId{0} , mBlockIds{blockIds} { TLLM_CHECK(mPool); } [[nodiscard]] BlockIterator begin() const; [[nodiscard]] Sentinel end() const { return {}; } [[nodiscard]] size_t size() const { return mBlockIds.size(); } [[nodiscard]] std::vector const& getBlockIds() const { return mBlockIds; } void setBlockIds(std::vector blockIds) { mBlockIds = std::move(blockIds); } [[nodiscard]] std::vector getBlockHashes() const { TLLM_CHECK(mManager); std::vector blockHashes; blockHashes.reserve(mBlockIds.size()); auto& blockManager = mManager->getBlockManager(); for (auto id : mBlockIds) { blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash()); } return blockHashes; } void updatePoolIdx(SizeType32 poolIdx) { TLLM_CHECK(mManager); mPool = mManager->getBlockManager().getPrimaryPool(poolIdx); auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx); if (newWindowSize != mWindowSize) { mWindowSize = newWindowSize; mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM); } } friend class BlockIterator; private: BlockRange( BaseKVCacheManager const& cacheManager, std::vector blockIds, LlmRequest::RequestIdType requestId) : mManager(&cacheManager) , mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX)) , mWindowSize(firstWindowSize(cacheManager)) , mRequestId(requestId) , mBlockIds(std::move(blockIds)) { } static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager) { constexpr SizeType32 FIRST_POOL_IDX = 0; return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX); } private: BaseKVCacheManager const* mManager; runtime::ITensor::SharedPtr mPool; SizeType32 mWindowSize; const LlmRequest::RequestIdType mRequestId; std::vector mBlockIds; static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0; static constexpr SizeType32 kFIRST_POOL_INDEX = 0; }; class BlockIterator { public: using iterator_category = std::forward_iterator_tag; using value_type = runtime::ITensor; using pointer = runtime::ITensor::SharedPtr; using reference = value_type&; using SizeType32 = tensorrt_llm::runtime::SizeType32; BlockIterator(BlockRange const* range, size_t idx) : mRange{range} , mIdx{idx} { TLLM_CHECK(mIdx == 0 || mIdx < mRange->mBlockIds.size()); update(); } [[nodiscard]] pointer operator->() { return mCurrent; } [[nodiscard]] reference operator*() { return *mCurrent; } BlockIterator& operator++() { mIdx++; update(); return *this; } BlockIterator operator++(int) { auto ret = *this; mIdx++; update(); return ret; } operator runtime::ITensor::SharedPtr() { return mCurrent; } [[nodiscard]] bool operator==(BlockIterator const& other) const { return mIdx == other.mIdx && mRange == other.mRange; } [[nodiscard]] bool operator==(BlockRange::Sentinel other) const { return mIdx == mRange->mBlockIds.size(); } template [[nodiscard]] bool operator!=(T const& other) const { return !(*this == other); } private: void update() { if (mIdx < mRange->mBlockIds.size()) { mCurrent = runtime::ITensor::slice(mRange->mPool, mRange->mBlockIds.at(mIdx), 1); } } BlockRange const* mRange; runtime::ITensor::SharedPtr mCurrent; size_t mIdx; }; inline BlockIterator BlockRange::begin() const { return {this, 0}; } } // namespace tensorrt_llm::batch_manager::kv_cache_manager