/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/nanobind/common/bindTypes.h"
#include "tensorrt_llm/nanobind/common/customCasters.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/ATen.h>
#include <torch/torch.h>

#include <nanobind/nanobind.h>
#include <nanobind/stl/bind_vector.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/vector.h>

#include <chrono>
#include <deque>
#include <memory>
#include <optional>
#include <stdexcept>
#include <vector>

namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace nb = nanobind;

using BlockKey = tbk::BlockKey;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using CacheBlockIds = std::vector<std::vector<SizeType32>>;

NB_MAKE_OPAQUE(CacheBlockIds);

namespace
{
// Wrap an optional torch tensor as a TensorRT-LLM ITensor view without copying the underlying data.
std::optional<tr::ITensor::SharedPtr> from_torch(std::optional<at::Tensor> torchPtr)
{
    if (torchPtr)
    {
        return tr::TorchView::of(torchPtr.value());
    }
    return std::nullopt;
}

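// nanobind "trampoline" classes: each pure-virtual call on the C++ interface is forwarded to a
// Python override, so Python subclasses of BaseKVCacheManager / BasePeftCacheManager can be used
// wherever the C++ interfaces are expected. The second NB_TRAMPOLINE argument reserves the number
// of override slots.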
class PyKvCacheManager : public tbk::BaseKVCacheManager
{
public:
    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 28);

    // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors

    void allocatePools(bool useUvm = false) override
    {
        NB_OVERRIDE_PURE(allocatePools, useUvm);
    }

    void releasePools() override
    {
        NB_OVERRIDE_PURE(releasePools);
    }

    void startScheduling() override
    {
        NB_OVERRIDE_PURE(startScheduling);
    }

    SizeType32 getTokensPerBlock() const override
    {
        NB_OVERRIDE_PURE(getTokensPerBlock);
    }

    SizeType32 getMaxNumBlocks() const override
    {
        NB_OVERRIDE_PURE(getMaxNumBlocks);
    }

    SizeType32 getNumPools() const override
    {
        NB_OVERRIDE_PURE(getNumPools);
    }

    tbk::KvCacheStats getKvCacheStats() const override
    {
        NB_OVERRIDE_PURE(getKvCacheStats);
    }

    void addToken(tb::LlmRequest::RequestIdType requestId) override
    {
        NB_OVERRIDE_PURE(addToken, requestId);
    }

    void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
        tensorrt_llm::common::OptionalRef<tb::LlmRequest> llmRequest = std::nullopt) override
    {
        NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest);
    }

    void removeSequence(tb::LlmRequest::RequestIdType requestId,
        tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest = std::nullopt) override
    {
        NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest);
    }

    tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override
    {
        NB_OVERRIDE_PURE(getSequence, requestId);
    }

    void schedulingRemoveSequence(tb::LlmRequest::RequestIdType requestId) override
    {
        NB_OVERRIDE_PURE(schedulingRemoveSequence, requestId);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getBlockPoolPointers() const override
    {
        NB_OVERRIDE_PURE(getBlockPoolPointers);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getLayerToPoolMapping() const override
    {
        NB_OVERRIDE_PURE(getLayerToPoolMapping);
    }

    void getBlockOffsetsOfBatch(tensorrt_llm::runtime::ITensor& output, SizeType32 firstBatchSlotIdx,
        SizeType32 batchSize, SizeType32 beamWidth) const override
    {
        NB_OVERRIDE_PURE(getBlockOffsetsOfBatch, output, firstBatchSlotIdx, batchSize, beamWidth);
    }

    SizeType32 copyBlockOffsets(tensorrt_llm::runtime::ITensor& output, SizeType32 outputSlotOffset,
        tb::LlmRequest::RequestIdType requestId) const override
    {
        NB_OVERRIDE_PURE(copyBlockOffsets, output, outputSlotOffset, requestId);
    }

    bool isEnableBlockReuse() const override
    {
        NB_OVERRIDE_PURE(isEnableBlockReuse);
    }

    void rewindKVCache(tb::LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override
    {
        NB_OVERRIDE_PURE(rewindKVCache, requestId, rewindLengths);
    }

    bool isCrossKv() const override
    {
        NB_OVERRIDE_PURE(isCrossKv);
    }

    std::optional<BlockKey> findNewContextBlock(
        VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest) const override
    {
        NB_OVERRIDE_PURE(findNewContextBlock, uniqueTokens, llmRequest);
    }

    void storeContextBlocks(tb::LlmRequest const& llmRequest) override
    {
        NB_OVERRIDE_PURE(storeContextBlocks, llmRequest);
    }

    std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
        tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override
    {
        NB_OVERRIDE_PURE(getCacheBlockIds, requestId, windowSize);
    }

    std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds(
        std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const override
    {
        NB_OVERRIDE_PURE(getBatchCacheBlockIds, requestIds, windowSize);
    }

    std::vector<SizeType32> getNewlyAllocatedBlockIds(
        tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override
    {
        NB_OVERRIDE_PURE(getNewlyAllocatedBlockIds, requestId, windowSize);
    }

    SizeType32 getUsedNumBlocks() const override
    {
        NB_OVERRIDE_PURE(getUsedNumBlocks);
    }

    SizeType32 getNumFreeBlocks() const override
    {
        NB_OVERRIDE_PURE(getNumFreeBlocks);
    }

    tbk::BlockManager const& getBlockManager() const override
    {
        NB_OVERRIDE_PURE(getBlockManager);
    }

    std::deque<tensorrt_llm::executor::KVCacheEvent> getLatestEvents(
        std::optional<std::chrono::milliseconds> timeout = std::nullopt) const override
    {
        NB_OVERRIDE_PURE(getLatestEvents, timeout);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override
    {
        NB_OVERRIDE_PURE(getPrimaryPool, layer_idx);
    }

    SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
    {
        NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
    }

    void refreshBlocks() override
    {
        NB_OVERRIDE_PURE(refreshBlocks);
    }

    void flushIterationEvents() override
    {
        NB_OVERRIDE_PURE(flushIterationEvents);
    }
};

// TODO: Deduplicate executor bindings KvCacheStats
class PyBasePeftCacheManager : public tb::BasePeftCacheManager
{
public:
    ~PyBasePeftCacheManager() override = default;

    NB_TRAMPOLINE(tb::BasePeftCacheManager, 8);

    void addRequestPeft(tb::BasePeftCacheManager::LlmRequestPtr llmRequest, bool tryGpuCache = true) override
    {
        NB_OVERRIDE_PURE(addRequestPeft, llmRequest, tryGpuCache);
    }

    tb::BasePeftCacheManager::PeftTable ensureBatch(tb::RequestVector const& contextRequests,
        tb::RequestVector const& generationRequests, bool resetGpuCache = false) override
    {
        NB_OVERRIDE_PURE(ensureBatch, contextRequests, generationRequests, resetGpuCache);
    }

    void resetDeviceCache() override
    {
        NB_OVERRIDE_PURE(resetDeviceCache);
    }

    void markRequestDone(tb::LlmRequest const& llmReq, bool pause = false) override
    {
        NB_OVERRIDE_PURE(markRequestDone, llmReq, pause);
    }

    tr::SizeType32 getMaxDevicePages() const override
    {
        NB_OVERRIDE_PURE(getMaxDevicePages);
    }

    tr::SizeType32 getMaxHostPages() const override
    {
        NB_OVERRIDE_PURE(getMaxHostPages);
    }

    tr::SizeType32 determineNumPages(std::shared_ptr<tb::LlmRequest> llmRequest) const override
    {
        NB_OVERRIDE_PURE(determineNumPages, llmRequest);
    }

    bool enabled() const override
    {
        NB_OVERRIDE_PURE(enabled);
    }
};
} // namespace

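// Registers the KV-cache manager types (stats, block keys, event manager, cache manager interfaces)
// on the given nanobind module. Most methods release the GIL around the C++ call so long-running
// cache operations do not block other Python threads.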
void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
{
    nb::class_<tbk::KvCacheStats>(m, "KvCacheStats")
        .def(nb::init<>())
        .def_rw("max_num_blocks", &tbk::KvCacheStats::maxNumBlocks)
        .def_rw("free_num_blocks", &tbk::KvCacheStats::freeNumBlocks)
        .def_rw("used_num_blocks", &tbk::KvCacheStats::usedNumBlocks)
        .def_rw("tokens_per_block", &tbk::KvCacheStats::toksPerBlock)
        .def_rw("alloc_total_blocks", &tbk::KvCacheStats::allocTotalBlocks)
        .def_rw("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
        .def_rw("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
        .def_rw("missed_blocks", &tbk::KvCacheStats::missedBlocks)
        .def_rw("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
        .def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize)
        .def_ro("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);

    nb::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
        .def(nb::init<>())
        .def_rw("paged_context_fmha", &tbk::TempAttentionWindowInputs::pagedContextFMHA)
        .def_rw("max_input_len", &tbk::TempAttentionWindowInputs::maxInputLen)
        .def_rw("max_num_tokens", &tbk::TempAttentionWindowInputs::maxNumTokens);

    nb::class_<tbk::BlockKey>(m, "BlockKey")
        .def(nb::init<>())
        .def(nb::init<VecTokens const&, std::optional<tb::LlmRequest::LoraTaskIdType>>(), nb::arg("tokens"),
            nb::arg("lora_task_id") = std::nullopt)
        .def(nb::init<bool, std::optional<tb::LlmRequest::LoraTaskIdType>, VecUniqueTokens const&>(),
            nb::arg("uses_extra_ids"), nb::arg("lora_task_id"), nb::arg("unique_tokens"))
        .def_ro("uses_extra_ids", &tbk::BlockKey::usesExtraIds)
        .def_ro("lora_task_id", &tbk::BlockKey::loraTaskId)
        .def_ro("unique_tokens", &tbk::BlockKey::uniqueTokens);

    nb::class_<tbk::BlockKeyHasher>(m, "BlockKeyHasher")
        .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0);

    nb::class_<tbk::KVCacheEventManager>(m, "KVCacheEventManager")
        .def(nb::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(),
            nb::arg("max_kv_event_entries"), nb::arg("attention_dp_rank") = std::nullopt,
            nb::arg("attention_dp_size") = std::nullopt, nb::arg("attention_dp_events_gather_period_ms") = 5);

    nb::class_<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager")
        .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"),
            nb::arg("is_cross_attention"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"),
            nb::arg("window_size_to_layers"), nb::arg("allotted_primary_mem_bytes"),
            nb::arg("allotted_secondary_mem_bytes"), nb::arg("extra_cost_memory"), nb::arg("kv_factor"),
            nb::call_guard<nb::gil_scoped_release>())
        .def("allocate_pools", &BaseKVCacheManager::allocatePools, nb::call_guard<nb::gil_scoped_release>())
        .def("release_pools", &BaseKVCacheManager::releasePools, nb::call_guard<nb::gil_scoped_release>())
        .def("start_scheduling", &BaseKVCacheManager::startScheduling, nb::call_guard<nb::gil_scoped_release>())
        .def_prop_ro("tokens_per_block", &BaseKVCacheManager::getTokensPerBlock)
        .def_prop_ro("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks)
        .def_prop_ro("num_pools", &BaseKVCacheManager::getNumPools)
        .def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats, nb::call_guard<nb::gil_scoped_release>())
        .def_prop_ro("max_blocks_per_seq",
            [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; })
        .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep,
            nb::call_guard<nb::gil_scoped_release>())
        .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion,
            nb::call_guard<nb::gil_scoped_release>())
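        // Per-request lifecycle hooks and tensor accessors used by the Python runtime on every step;
        // the pool/offset getters convert internal ITensor handles to torch tensors via tr::Torch.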
        .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard<nb::gil_scoped_release>())
        .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard<nb::gil_scoped_release>())
        .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard<nb::gil_scoped_release>())
        .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_block_pool_pointers",
            [](tbk::BaseKVCacheManager& self)
            {
                std::optional<at::Tensor> block_pool_pointers{std::nullopt};
                auto tensor = self.getBlockPoolPointers();
                if (tensor)
                {
                    std::shared_ptr<tr::ITensor> _tensor = std::move(tensor);
                    block_pool_pointers = tr::Torch::tensor(_tensor);
                }
                return block_pool_pointers;
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_block_scale_pool_pointers",
            [](tbk::BaseKVCacheManager& self)
            {
                std::optional<at::Tensor> block_scale_pool_pointers{std::nullopt};
                auto tensor = self.getBlockScalePoolPointers();
                if (tensor)
                {
                    std::shared_ptr<tr::ITensor> _tensor = std::move(tensor);
                    block_scale_pool_pointers = tr::Torch::tensor(_tensor);
                }
                return block_scale_pool_pointers;
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_layer_to_pool_mapping",
            [](tbk::BaseKVCacheManager& self)
            {
                std::optional<at::Tensor> layer_to_pool_mapping{std::nullopt};
                auto tensor = self.getLayerToPoolMapping();
                if (tensor)
                {
                    std::shared_ptr<tr::ITensor> _tensor = std::move(tensor);
                    layer_to_pool_mapping = tr::Torch::tensor(_tensor);
                }
                return layer_to_pool_mapping;
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_primary_pool_data",
            [](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor
            {
                auto pool = tr::Torch::tensor(self.getPrimaryPool(layer_idx));
                auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
                return pool.index({torch::indexing::Slice(), pool_layer_idx});
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_block_offsets_of_batch",
            [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
                SizeType32 beamWidth)
            {
                auto _output = from_torch(output);
                TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
                self.getBlockOffsetsOfBatch(*(_output.value()), firstBatchSlotIdx, batchSize, beamWidth);
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "copy_block_offsets",
            [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 outputSlotOffset,
                tb::LlmRequest::RequestIdType requestId)
            {
                auto _output = from_torch(output);
                TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
                auto maxBlockCount = self.copyBlockOffsets(*(_output.value()), outputSlotOffset, requestId);
                return maxBlockCount;
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "copy_batch_block_offsets",
            [](tbk::BaseKVCacheManager& self, at::Tensor output,
                std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 const beamWidth,
                SizeType32 const offset)
            {
                auto _output = from_torch(output);
                TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
                for (size_t i = 0; i < requestIds.size(); ++i)
                {
                    self.copyBlockOffsets(*(_output.value()), i * beamWidth + offset, requestIds[i]);
                }
            },
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_latest_events",
            [](tbk::BaseKVCacheManager& self, std::optional<double> timeout_ms = std::nullopt)
            {
                if (timeout_ms)
                {
                    return self.getLatestEvents(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
                }
                return self.getLatestEvents(std::nullopt);
            },
            nb::arg("timeout_ms") = std::nullopt, nb::call_guard<nb::gil_scoped_release>())
        .def_prop_ro("enable_block_reuse", &BaseKVCacheManager::isEnableBlockReuse)
        .def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache, nb::call_guard<nb::gil_scoped_release>())
        .def_prop_ro("cross_kv", &BaseKVCacheManager::isCrossKv)
        .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
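        // get_cache_block_ids returns the opaque CacheBlockIds container bound via nb::bind_vector
        // below, so the nested block-id vectors are exposed by reference instead of being converted
        // to Python lists on every call.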
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, nb::call_guard()) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, nb::call_guard()); nb::bind_vector(m, "CacheBlockIds") .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); }) .def("__setstate__", [](CacheBlockIds& self, nb::tuple const& t) { if (t.size() != 1) throw std::runtime_error("Invalid state!"); new (&self) CacheBlockIds(nb::cast>>(t[0])); }); nb::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) .value("CROSS", tbk::CacheType::kCROSS) .value("SELFKONLY", tbk::CacheType::kSELFKONLY); nb::class_(m, "KVCacheManager") .def( nb::init const&, SizeType32, SizeType32, std::map> const&, SizeType32, SizeType32, std::vector const&, std::optional const&, nvinfer1::DataType, SizeType32, int64_t, std::optional, bool, bool, tbk::CacheType, std::optional, std::shared_ptr, bool, bool, std::shared_ptr>(), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"), nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"), nb::arg("sink_token_length"), nb::arg("stream"), nb::arg("max_sequence_length").none(), nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true, nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt, nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr, nb::call_guard()); } void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) { nb::class_(m, "BasePeftCacheManager") .def("add_request_peft", &tb::BasePeftCacheManager::addRequestPeft, nb::arg("request"), nb::arg("try_gpu_cache") = true, nb::call_guard()) .def( "ensure_batch", [](tb::BasePeftCacheManager& self, tb::RequestVector const& contextRequests, tb::RequestVector const& generationRequests, bool resetGpuCache) { return self.ensureBatch(contextRequests, generationRequests, resetGpuCache); }, nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false, nb::call_guard()) .def( "reset_device_cache", &tb::BasePeftCacheManager::resetDeviceCache, nb::call_guard()) .def("mark_request_done", &tb::BasePeftCacheManager::markRequestDone, nb::arg("request"), nb::arg("pause") = false, nb::call_guard()) .def_prop_ro("max_device_pages", &tb::BasePeftCacheManager::getMaxDevicePages) .def_prop_ro("max_host_pages", &tb::BasePeftCacheManager::getMaxHostPages) .def("determine_num_pages", &tb::BasePeftCacheManager::determineNumPages, nb::arg("request"), nb::call_guard()) .def_prop_ro("enabled", &tb::BasePeftCacheManager::enabled); nb::class_(m, "PeftCacheManager") .def(nb::init(), nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"), nb::call_guard()) .def("is_task_cached", &tb::PeftCacheManager::isTaskCached, nb::arg("taskId"), nb::call_guard()); nb::class_(m, "NoOpPeftCacheManager") .def(nb::init<>(), nb::call_guard()); }