/* * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include namespace tensorrt_llm::batch_manager::kv_cache_manager { enum class CacheType { kSELF = 0, kCROSS = 1, kSELFKONLY = 2, }; //! @brief Encapsulates parameters to configure paged KV cache. class KvCacheConfig { public: using SizeType32 = tensorrt_llm::runtime::SizeType32; explicit KvCacheConfig(std::optional maxTokens = std::nullopt, std::optional> maxAttentionWindowVec = std::nullopt, std::optional sinkTokenLength = std::nullopt, std::optional freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = true, bool useUvm = false, std::optional hostCacheSize = std::nullopt, bool onboardBlocks = true, std::optional crossKvCacheFraction = std::nullopt, std::optional secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0, bool enablePartialReuse = true, bool copyOnPartialReuse = true) : maxTokens{maxTokens} , maxAttentionWindowVec{std::move(maxAttentionWindowVec)} , sinkTokenLength{sinkTokenLength} , freeGpuMemoryFraction{freeGpuMemoryFraction} , enableBlockReuse(enableBlockReuse) , useUvm(useUvm) , hostCacheSize(hostCacheSize) , onboardBlocks(onboardBlocks) , crossKvCacheFraction{crossKvCacheFraction} , secondaryOffloadMinPriority(secondaryOffloadMinPriority) , eventBufferMaxSize(eventBufferMaxSize) , enablePartialReuse(enablePartialReuse) , copyOnPartialReuse(copyOnPartialReuse) { } explicit KvCacheConfig(executor::KvCacheConfig const& kvCacheConfig) : KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindowVec(), kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(), kvCacheConfig.getEnableBlockReuse(), false, kvCacheConfig.getHostCacheSize(), kvCacheConfig.getOnboardBlocks(), kvCacheConfig.getCrossKvCacheFraction(), kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig.getEventBufferMaxSize(), kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()) { } bool operator==(KvCacheConfig const& other) const { return maxTokens == other.maxTokens && maxAttentionWindowVec == other.maxAttentionWindowVec && sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction && enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm && hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks && crossKvCacheFraction == other.crossKvCacheFraction && secondaryOffloadMinPriority == other.secondaryOffloadMinPriority && eventBufferMaxSize == other.eventBufferMaxSize && enablePartialReuse == other.enablePartialReuse && copyOnPartialReuse == other.copyOnPartialReuse; } friend std::ostream& operator<<(std::ostream& os, KvCacheConfig const& self); std::optional maxTokens; std::optional> maxAttentionWindowVec; std::optional sinkTokenLength; std::optional freeGpuMemoryFraction; bool enableBlockReuse; static constexpr auto kDefaultGpuMemFraction = 0.9F; bool useUvm; std::optional hostCacheSize; bool onboardBlocks; // Cross will use crossKvCacheFraction of KV Cache and self attention will use the rest. std::optional crossKvCacheFraction; // The minimum priority level to allow blocks to be offloaded to secondary memory. std::optional secondaryOffloadMinPriority; // Maximum size of the KV Cache event buffer size_t eventBufferMaxSize; bool enablePartialReuse; bool copyOnPartialReuse; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager