diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index d81be0d1e9..c4c6659294 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -1687,6 +1687,14 @@ public:
         = 0;
 
     virtual void unpinBlocksById(std::vector const& blockIds) = 0;
+
+    //! @brief Get the retention priority of a block by its ID.
+    //! @param blockId The ID of the block.
+    //! @param windowSize The attention window size this block belongs to.
+    //! @return The retention priority of the block, or default priority if block not found.
+    [[nodiscard]] virtual executor::RetentionPriority getPriorityByBlockId(
+        KVCacheBlock::IdType blockId, SizeType32 windowSize) const
+        = 0;
 };
 
 class KVCacheManager : public BaseKVCacheManager
@@ -1970,6 +1978,9 @@ public:
 
     void unpinBlocksById(std::vector const& blockIds) override;
 
+    [[nodiscard]] executor::RetentionPriority getPriorityByBlockId(
+        KVCacheBlock::IdType blockId, SizeType32 windowSize) const override;
+
     std::optional getLastBlockId(LlmRequest::RequestIdType requestId) const override;
 
     /// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index fd5f5487f2..8e7b6ed5a8 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -2605,6 +2605,26 @@ void KVCacheManager::unpinBlocksById(std::vector const& bl
     mBlockManager.unpinBlocksById(blockIds);
 }
 
+tle::RetentionPriority KVCacheManager::getPriorityByBlockId(KVCacheBlock::IdType blockId, SizeType32 windowSize) const
+{
+    try
+    {
+        BlockPtr const& block = mBlockManager.getBlockById(blockId, windowSize);
+        if (block)
+        {
+            return block->getPriority();
+        }
+        TLLM_LOG_WARNING("getPriorityByBlockId: Block ID %d not found in window %d", blockId, windowSize);
+        return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
+    }
+    catch (std::out_of_range const& ex)
+    {
+        TLLM_LOG_WARNING(
+            "getPriorityByBlockId: Block ID %d or window size %d out of range: %s", blockId, windowSize, ex.what());
+        return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
+    }
+}
+
 SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
 {
     auto const& sequence = getSequence(requestId);
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index a56fc38d00..efd73e5caf 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -198,6 +198,7 @@ void initBindings(nb::module_& m)
         .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
         .def_prop_ro("is_child", &GenLlmReq::isChild)
         .def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID)
+        .def_prop_ro("kv_cache_retention_config", &GenLlmReq::getKvCacheRetentionConfig)
         .def_prop_ro("multimodal_hashes",
             [](GenLlmReq& self)
             {
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index d642e609b9..8c9018bc2b 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -494,7 +494,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
-        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
+        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>())
+        .def("get_priority_by_block_id", &BaseKVCacheManager::getPriorityByBlockId, nb::arg("block_id"),
+            nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>());
 
     nb::bind_vector(m, "CacheBlockIds")
         .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })
diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index b3535c7b11..3faee32f4f 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -4094,6 +4094,61 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority)
     }
 }
 
+TEST_F(KVCacheManagerTest, GetPriorityByBlockId)
+{
+    auto constexpr numLayers = 2;
+    auto constexpr numKvHeads = 2;
+    auto constexpr sizePerHead = 16;
+    auto constexpr tokensPerBlock = 4;
+    auto constexpr numBlocks = 8;
+    auto constexpr maxAttentionWindow = 32;
+    auto constexpr maxNumSequences = 4;
+    auto constexpr beamWidth = 1;
+    auto constexpr dtype = nvinfer1::DataType::kHALF;
+    auto const stream = std::make_shared<tr::CudaStream>();
+    SizeType32 constexpr maxNewTokens = 4;
+    tr::SamplingConfig const samplingConfig{beamWidth};
+    bool constexpr isStreaming{false};
+    tle::RetentionPriority constexpr highPriority = 80;
+
+    auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {numBlocks, 0}}};
+
+    KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
+        beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream,
+        maxAttentionWindow, true);
+    kvCacheManager.allocatePools(false);
+
+    // Create a sequence and set a custom priority
+    auto inputTokens = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7});
+    auto const inputLength = static_cast<SizeType32>(inputTokens->size());
+    auto llmRequest = std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens, samplingConfig, isStreaming);
+
+    // Set high priority for context blocks
+    llmRequest->setKvCacheRetentionConfig(KvCacheRetentionConfig(
+        {KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, highPriority)}, highPriority));
+
+    kvCacheManager.addSequence(0, inputLength, beamWidth, llmRequest);
+    kvCacheManager.storeContextBlocks(*llmRequest);
+
+    // Get block IDs for the sequence
+    auto const& seq = kvCacheManager.getSequence(0);
+    auto cacheBlockIds = seq.getCacheBlockIds(maxAttentionWindow).at(0);
+    ASSERT_GE(cacheBlockIds.size(), 1);
+
+    // Test 1: Valid block ID should return the set priority
+    auto const validBlockId = cacheBlockIds[0];
+    auto const retrievedPriority = kvCacheManager.getPriorityByBlockId(validBlockId, maxAttentionWindow);
+    EXPECT_EQ(retrievedPriority, highPriority);
+
+    // Test 2: Invalid block ID (negative) should return default priority
+    auto const invalidNegative = kvCacheManager.getPriorityByBlockId(-1, maxAttentionWindow);
+    EXPECT_EQ(invalidNegative, KvCacheRetentionConfig::kDefaultRetentionPriority);
+
+    // Test 3: Invalid block ID (out of range) should return default priority
+    auto const invalidOutOfRange = kvCacheManager.getPriorityByBlockId(9999, maxAttentionWindow);
+    EXPECT_EQ(invalidOutOfRange, KvCacheRetentionConfig::kDefaultRetentionPriority);
+}
+
 TEST(KVCacheManagerHelpersTest, ChopVectorIntoBlocksBasicNoPartial)
 {
     using namespace tensorrt_llm::batch_manager::kv_cache_manager;
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
index 380486e935..e2000a35b5 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
@@ -69,6 +69,9 @@ class RequestData:
     computed_position: int
     # The number of scheduled tokens for the upcoming forward pass.
    num_scheduled_tokens: int
+    # The retention priorities for each new block (same length as new_block_ids).
+    # Used for priority-based offload filtering. None means use default priority.
+    priorities: Optional[List[int]] = None
 
 
 # A class to store some basic data regarding all inflight requests.
@@ -314,8 +317,17 @@ class KvCacheConnectorSchedulerOutputRequest:
             num_scheduled_tokens = 1 + get_draft_token_length(
                 req)  # Specdec with draft tokens is not supported yet.
 
+        # Get retention priority for each new block only if retention config is provided
+        # (for priority-based offload filtering)
+        priorities = None
+        if req.kv_cache_retention_config is not None:
+            priorities = [
+                kv_cache_manager.get_priority_by_block_id(block_id)
+                for block_id in new_block_ids
+            ]
+
         return RequestData(req.request_id, new_tokens, new_block_ids,
-                           computed_position, num_scheduled_tokens)
+                           computed_position, num_scheduled_tokens, priorities)
 
 
 class KvCacheConnectorSchedulerOutputManager:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 5cf2458bd4..ce6da4b02c 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -560,6 +560,14 @@ def create_py_executor(
             raise NotImplementedError(
                 "KV connector is only supported with guaranteed no evict scheduler policy."
             )
+
+        max_attention_window = kv_cache_config.max_attention_window
+        if max_attention_window is not None and len(
+                set(max_attention_window)) > 1:
+            raise NotImplementedError(
+                "KV connector is not supported with VSWA (Variable Sliding Window Attention)."
+            )
+
         try:
             module = importlib.import_module(
                 kv_connector_config.connector_module)
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 21a98dbf27..a5548946d8 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -901,6 +901,25 @@ class KVCacheManager(BaseResourceManager):
     def get_last_block_id(self, request_id: int) -> int:
         return self.impl.get_last_block_id(request_id)
 
+    def get_priority_by_block_id(self,
+                                 block_id: int,
+                                 window_size: Optional[int] = None) -> int:
+        """Get the retention priority of a block by its ID.
+
+        Args:
+            block_id: The ID of the block.
+            window_size: The attention window size this block belongs to.
+                Required for VSWA configurations with multiple window sizes.
+
+        Returns:
+            The retention priority of the block (0-100), or default priority (35) if not found.
+ """ + if window_size is None: + if len(self.max_attention_window_vec) > 1: + raise ValueError("window_size must be provided for VSWA") + window_size = self.max_attention_window_vec[0] + return self.impl.get_priority_by_block_id(block_id, window_size) + def get_batch_cache_indices( self, request_ids: List[int], diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index f3053d73f1..16b50785f7 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -22,6 +22,7 @@ import pytest from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig, KvCacheConfig, KvCacheConnectorConfig) +from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig from ..conftest import llm_models_root @@ -439,3 +440,100 @@ def test_connector_multi_request(enforce_single_worker, model_with_connector): # The KV cache of both prior requests should be freed, allowing the third request to run. model.generate([2] * 110, sampling_params=sampling_params) + + +@pytest.mark.threadleak(enabled=False) +def test_connector_priorities(enforce_single_worker, model_with_connector): + """Test that retention priorities flow through the connector correctly. + + This test verifies that when KvCacheRetentionConfig is provided, + the RequestData.priorities field is populated with the correct + per-block priorities based on the token ranges. + """ + BLOCK_SIZE = 32 + NUM_INPUT_TOKENS = 64 # 2 blocks + NUM_TOKENS = 4 + HIGH_PRIORITY = 80 # For system prompt blocks + LOW_PRIORITY = 10 # For user input / decode blocks + + model_fn, scheduler, worker = model_with_connector + + model = model_fn(disable_overlap_scheduler=True) + + scheduler.get_num_new_matched_tokens.return_value = 0, False + worker.get_finished.return_value = [], [] + + # Create retention config with different priorities for different token ranges: + # - First 32 tokens (block 0): high priority (e.g., system prompt) + # - Remaining tokens (block 1+): low priority (e.g., user input) + retention_config = KvCacheRetentionConfig( + token_range_retention_priorities=[ + KvCacheRetentionConfig.TokenRangeRetentionConfig( + token_start=0, + token_end=32, + priority=HIGH_PRIORITY, + ), + KvCacheRetentionConfig.TokenRangeRetentionConfig( + token_start=32, + token_end=None, # Extend to end of sequence + priority=LOW_PRIORITY, + ), + ], + decode_retention_priority=LOW_PRIORITY, + ) + + sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True) + + generate_and_sleep(model, [0] * NUM_INPUT_TOKENS, + sampling_params=sampling_params, + kv_cache_retention_config=retention_config) + + # Verify that build_connector_meta was called + assert scheduler.build_connector_meta.call_count >= 1 + + # Check the first call (new request) has priorities set + first_call = scheduler.build_connector_meta.call_args_list[0] + sched_output = first_call.args[0] + + assert len(sched_output.new_requests) == 1 + request = sched_output.new_requests[0] + + # Should have 2 blocks for 64 input tokens with block size 32 + expected_num_blocks = math.ceil(NUM_INPUT_TOKENS / BLOCK_SIZE) + assert len(request.new_block_ids) == expected_num_blocks + + # Priorities should be set and match the retention config + assert request.priorities is not None + assert len(request.priorities) == len(request.new_block_ids) + + # First block should have high priority, second block should have low priority + 
assert request.priorities[ + 0] == HIGH_PRIORITY, f"Expected priority {HIGH_PRIORITY} for block 0, got {request.priorities[0]}" + assert request.priorities[ + 1] == LOW_PRIORITY, f"Expected priority {LOW_PRIORITY} for block 1, got {request.priorities[1]}" + + +@pytest.mark.threadleak(enabled=False) +def test_connector_priorities_default(enforce_single_worker, + model_with_connector): + """Test that priorities are None when no retention config is provided.""" + model_fn, scheduler, worker = model_with_connector + + model = model_fn(disable_overlap_scheduler=True) + + scheduler.get_num_new_matched_tokens.return_value = 0, False + worker.get_finished.return_value = [], [] + + sampling_params = SamplingParams(max_tokens=4, ignore_eos=True) + + # Generate without retention config + generate_and_sleep(model, [0] * 48, sampling_params=sampling_params) + + first_call = scheduler.build_connector_meta.call_args_list[0] + sched_output = first_call.args[0] + + assert len(sched_output.new_requests) == 1 + request = sched_output.new_requests[0] + + # Without retention config, priorities should be None + assert request.priorities is None
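
Reviewer note: below is a minimal, illustrative sketch (not part of the diff) of how a connector scheduler might consume the new RequestData.priorities field for priority-based offload filtering. The RequestData fields (new_block_ids, priorities), the default retention priority of 35, and the "None means use default priority" convention come from the changes above; the OffloadDecision container, the _plan_offloads helper, and the PRIORITY_THRESHOLD value are hypothetical names invented for the example.

from dataclasses import dataclass
from typing import List, Optional

# Mirrors KvCacheRetentionConfig::kDefaultRetentionPriority used on the C++ side.
DEFAULT_RETENTION_PRIORITY = 35
# Hypothetical cut-off: only blocks at or above this priority are offloaded.
PRIORITY_THRESHOLD = 50


@dataclass
class OffloadDecision:
    """Hypothetical per-request offload plan derived from RequestData fields."""
    request_id: int
    block_ids: List[int]


def _plan_offloads(request_id: int, new_block_ids: List[int],
                   priorities: Optional[List[int]]) -> OffloadDecision:
    """Keep only blocks whose retention priority clears the threshold.

    When priorities is None (no KvCacheRetentionConfig on the request),
    every block falls back to the default priority, mirroring what
    getPriorityByBlockId returns for unknown blocks.
    """
    if priorities is None:
        priorities = [DEFAULT_RETENTION_PRIORITY] * len(new_block_ids)
    selected = [
        block_id for block_id, priority in zip(new_block_ids, priorities)
        if priority >= PRIORITY_THRESHOLD
    ]
    return OffloadDecision(request_id=request_id, block_ids=selected)


# Example: with the retention config from test_connector_priorities,
# block 0 (priority 80) is offloaded and block 1 (priority 10) is skipped.
plan = _plan_offloads(request_id=0, new_block_ids=[12, 13], priorities=[80, 10])
assert plan.block_ids == [12]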