[None][feat] Add priority-based KV cache offload filtering support (#10751)

Signed-off-by: Yuewei Na <yna@nvidia.com>
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
Authored by Yuewei Na on 2026-02-05 02:22:56 -08:00; committed by GitHub
parent 9601b17459
commit 0d18b2d7a4
9 changed files with 228 additions and 2 deletions

View File

@@ -1687,6 +1687,14 @@ public:
= 0;
virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
//! @brief Get the retention priority of a block by its ID.
//! @param blockId The ID of the block.
//! @param windowSize The attention window size this block belongs to.
//! @return The retention priority of the block, or the default priority if the block is not found.
[[nodiscard]] virtual executor::RetentionPriority getPriorityByBlockId(
KVCacheBlock::IdType blockId, SizeType32 windowSize) const
= 0;
};
class KVCacheManager : public BaseKVCacheManager
@@ -1970,6 +1978,9 @@ public:
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
[[nodiscard]] executor::RetentionPriority getPriorityByBlockId(
KVCacheBlock::IdType blockId, SizeType32 windowSize) const override;
std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.

View File

@@ -2605,6 +2605,26 @@ void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& bl
mBlockManager.unpinBlocksById(blockIds);
}
tle::RetentionPriority KVCacheManager::getPriorityByBlockId(KVCacheBlock::IdType blockId, SizeType32 windowSize) const
{
try
{
BlockPtr const& block = mBlockManager.getBlockById(blockId, windowSize);
if (block)
{
return block->getPriority();
}
TLLM_LOG_WARNING("getPriorityByBlockId: Block ID %d not found in window %d", blockId, windowSize);
return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
}
catch (std::out_of_range const& ex)
{
TLLM_LOG_WARNING(
"getPriorityByBlockId: Block ID %d or window size %d out of range: %s", blockId, windowSize, ex.what());
return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
}
}
SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
{
auto const& sequence = getSequence(requestId);
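
The implementation above deliberately degrades to the default priority instead of throwing, so callers never need to guard the lookup. A minimal sketch of the resulting contract, written against the `get_priority_by_block_id` Python binding added later in this commit; the constructed `kv_cache_manager` object and the block IDs are assumptions for illustration:

    DEFAULT_PRIORITY = 35  # mirrors tle::KvCacheRetentionConfig::kDefaultRetentionPriority

    def block_priorities(kv_cache_manager, block_ids, window_size):
        # Unknown or stale IDs do not raise; the manager falls back to the
        # default priority, exactly as the C++ catch branch above does.
        return [
            kv_cache_manager.get_priority_by_block_id(block_id, window_size)
            for block_id in block_ids
        ]

    # e.g. block_priorities(mgr, [0, 1, -1], 4096) -> [80, 80, 35] when the
    # first two blocks were stored at priority 80 and -1 is not a real block.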

View File

@@ -198,6 +198,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
.def_prop_ro("is_child", &GenLlmReq::isChild)
.def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID)
.def_prop_ro("kv_cache_retention_config", &GenLlmReq::getKvCacheRetentionConfig)
.def_prop_ro("multimodal_hashes",
[](GenLlmReq& self)
{

View File

@@ -494,7 +494,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>())
.def("get_priority_by_block_id", &BaseKVCacheManager::getPriorityByBlockId, nb::arg("block_id"),
nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>());
nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
.def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })

View File

@@ -4094,6 +4094,61 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority)
}
}
TEST_F(KVCacheManagerTest, GetPriorityByBlockId)
{
auto constexpr numLayers = 2;
auto constexpr numKvHeads = 2;
auto constexpr sizePerHead = 16;
auto constexpr tokensPerBlock = 4;
auto constexpr numBlocks = 8;
auto constexpr maxAttentionWindow = 32;
auto constexpr maxNumSequences = 4;
auto constexpr beamWidth = 1;
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const stream = std::make_shared<tr::CudaStream>();
SizeType32 constexpr maxNewTokens = 4;
tr::SamplingConfig const samplingConfig{beamWidth};
bool constexpr isStreaming{false};
tle::RetentionPriority constexpr highPriority = 80;
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {numBlocks, 0}}};
KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
beamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, dtype, 0, stream,
maxAttentionWindow, true);
kvCacheManager.allocatePools(false);
// Create a sequence and set a custom priority
auto inputTokens = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7});
auto const inputLength = static_cast<SizeType32>(inputTokens->size());
auto llmRequest = std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens, samplingConfig, isStreaming);
// Set high priority for context blocks
llmRequest->setKvCacheRetentionConfig(KvCacheRetentionConfig(
{KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, highPriority)}, highPriority));
kvCacheManager.addSequence(0, inputLength, beamWidth, llmRequest);
kvCacheManager.storeContextBlocks(*llmRequest);
// Get block IDs for the sequence
auto const& seq = kvCacheManager.getSequence(0);
auto cacheBlockIds = seq.getCacheBlockIds(maxAttentionWindow).at(0);
ASSERT_GE(cacheBlockIds.size(), 1);
// Test 1: Valid block ID should return the set priority
auto const validBlockId = cacheBlockIds[0];
auto const retrievedPriority = kvCacheManager.getPriorityByBlockId(validBlockId, maxAttentionWindow);
EXPECT_EQ(retrievedPriority, highPriority);
// Test 2: Invalid block ID (negative) should return default priority
auto const invalidNegative = kvCacheManager.getPriorityByBlockId(-1, maxAttentionWindow);
EXPECT_EQ(invalidNegative, KvCacheRetentionConfig::kDefaultRetentionPriority);
// Test 3: Invalid block ID (out of range) should return default priority
auto const invalidOutOfRange = kvCacheManager.getPriorityByBlockId(9999, maxAttentionWindow);
EXPECT_EQ(invalidOutOfRange, KvCacheRetentionConfig::kDefaultRetentionPriority);
}
TEST(KVCacheManagerHelpersTest, ChopVectorIntoBlocksBasicNoPartial)
{
using namespace tensorrt_llm::batch_manager::kv_cache_manager;

View File

@@ -69,6 +69,9 @@ class RequestData:
computed_position: int
# The number of scheduled tokens for the upcoming forward pass.
num_scheduled_tokens: int
# The retention priorities for each new block (same length as new_block_ids).
# Used for priority-based offload filtering. None means use default priority.
priorities: Optional[List[int]] = None
# A class to store some basic data regarding all inflight requests.
@@ -314,8 +317,17 @@ class KvCacheConnectorSchedulerOutputRequest:
num_scheduled_tokens = 1 + get_draft_token_length(
req) # Specdec with draft tokens is not supported yet.
# Get retention priority for each new block only if retention config is provided
# (for priority-based offload filtering)
priorities = None
if req.kv_cache_retention_config is not None:
priorities = [
kv_cache_manager.get_priority_by_block_id(block_id)
for block_id in new_block_ids
]
return RequestData(req.request_id, new_tokens, new_block_ids,
- computed_position, num_scheduled_tokens)
+ computed_position, num_scheduled_tokens, priorities)
class KvCacheConnectorSchedulerOutputManager:
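
This `priorities` field is what enables the offload filtering named in the commit title: a connector can skip low-priority blocks when deciding what to copy out to external storage. A hedged sketch of such a consumer; the threshold and the helper name are hypothetical, and only `RequestData.new_block_ids` and `RequestData.priorities` come from this commit:

    from typing import List

    OFFLOAD_PRIORITY_THRESHOLD = 50  # assumption: tune per deployment

    def select_blocks_to_offload(request_data: RequestData) -> List[int]:
        if request_data.priorities is None:
            # No retention config was supplied, so every block carries the
            # default priority; offload all of them.
            return list(request_data.new_block_ids)
        return [
            block_id
            for block_id, priority in zip(request_data.new_block_ids,
                                          request_data.priorities)
            if priority >= OFFLOAD_PRIORITY_THRESHOLD
        ]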

View File

@@ -560,6 +560,14 @@ def create_py_executor(
raise NotImplementedError(
"KV connector is only supported with guaranteed no evict scheduler policy."
)
max_attention_window = kv_cache_config.max_attention_window
if max_attention_window is not None and len(
set(max_attention_window)) > 1:
raise NotImplementedError(
"KV connector is not supported with VSWA (Variable Sliding Window Attention)."
)
try:
module = importlib.import_module(
kv_connector_config.connector_module)
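
The guard above only rejects window lists containing more than one distinct size, so a repeated uniform window still works with a connector. For illustration, using the `KvCacheConfig` import path that appears in the test file below:

    from tensorrt_llm.llmapi.llm_args import KvCacheConfig

    # Fine with a KV connector: one distinct window size, even if repeated.
    uniform = KvCacheConfig(max_attention_window=[4096, 4096])

    # Rejected with a KV connector: two distinct sizes make this a VSWA
    # config, which trips the NotImplementedError above.
    vswa = KvCacheConfig(max_attention_window=[4096, 1024])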

View File

@@ -901,6 +901,25 @@ class KVCacheManager(BaseResourceManager):
def get_last_block_id(self, request_id: int) -> int:
return self.impl.get_last_block_id(request_id)
def get_priority_by_block_id(self,
block_id: int,
window_size: Optional[int] = None) -> int:
"""Get the retention priority of a block by its ID.
Args:
block_id: The ID of the block.
window_size: The attention window size this block belongs to.
Required for VSWA configurations with multiple window sizes.
Returns:
The retention priority of the block (0-100), or default priority (35) if not found.
"""
if window_size is None:
if len(self.max_attention_window_vec) > 1:
raise ValueError("window_size must be provided for VSWA")
window_size = self.max_attention_window_vec[0]
return self.impl.get_priority_by_block_id(block_id, window_size)
def get_batch_cache_indices(
self,
request_ids: List[int],
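
Callers with a single configured window size can omit `window_size`; under VSWA the wrapper refuses to guess. A short usage sketch, assuming a constructed `KVCacheManager` resource manager (`kv_cache_manager`) and a `block_id` obtained from a scheduled request:

    # Single window size configured: window_size may be omitted.
    priority = kv_cache_manager.get_priority_by_block_id(block_id)

    # VSWA (multiple window sizes): window_size is required, otherwise the
    # wrapper raises ValueError before reaching the C++ binding.
    priority = kv_cache_manager.get_priority_by_block_id(block_id, window_size=4096)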

View File

@@ -22,6 +22,7 @@ import pytest
from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig, KvCacheConfig,
KvCacheConnectorConfig)
from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig
from ..conftest import llm_models_root
@@ -439,3 +440,100 @@ def test_connector_multi_request(enforce_single_worker, model_with_connector):
# The KV cache of both prior requests should be freed, allowing the third request to run.
model.generate([2] * 110, sampling_params=sampling_params)
@pytest.mark.threadleak(enabled=False)
def test_connector_priorities(enforce_single_worker, model_with_connector):
"""Test that retention priorities flow through the connector correctly.
This test verifies that when KvCacheRetentionConfig is provided,
the RequestData.priorities field is populated with the correct
per-block priorities based on the token ranges.
"""
BLOCK_SIZE = 32
NUM_INPUT_TOKENS = 64 # 2 blocks
NUM_TOKENS = 4
HIGH_PRIORITY = 80 # For system prompt blocks
LOW_PRIORITY = 10 # For user input / decode blocks
model_fn, scheduler, worker = model_with_connector
model = model_fn(disable_overlap_scheduler=True)
scheduler.get_num_new_matched_tokens.return_value = 0, False
worker.get_finished.return_value = [], []
# Create retention config with different priorities for different token ranges:
# - First 32 tokens (block 0): high priority (e.g., system prompt)
# - Remaining tokens (block 1+): low priority (e.g., user input)
retention_config = KvCacheRetentionConfig(
token_range_retention_priorities=[
KvCacheRetentionConfig.TokenRangeRetentionConfig(
token_start=0,
token_end=32,
priority=HIGH_PRIORITY,
),
KvCacheRetentionConfig.TokenRangeRetentionConfig(
token_start=32,
token_end=None, # Extend to end of sequence
priority=LOW_PRIORITY,
),
],
decode_retention_priority=LOW_PRIORITY,
)
sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)
generate_and_sleep(model, [0] * NUM_INPUT_TOKENS,
sampling_params=sampling_params,
kv_cache_retention_config=retention_config)
# Verify that build_connector_meta was called
assert scheduler.build_connector_meta.call_count >= 1
# Check the first call (new request) has priorities set
first_call = scheduler.build_connector_meta.call_args_list[0]
sched_output = first_call.args[0]
assert len(sched_output.new_requests) == 1
request = sched_output.new_requests[0]
# Should have 2 blocks for 64 input tokens with block size 32
expected_num_blocks = math.ceil(NUM_INPUT_TOKENS / BLOCK_SIZE)
assert len(request.new_block_ids) == expected_num_blocks
# Priorities should be set and match the retention config
assert request.priorities is not None
assert len(request.priorities) == len(request.new_block_ids)
# First block should have high priority, second block should have low priority
assert request.priorities[
0] == HIGH_PRIORITY, f"Expected priority {HIGH_PRIORITY} for block 0, got {request.priorities[0]}"
assert request.priorities[
1] == LOW_PRIORITY, f"Expected priority {LOW_PRIORITY} for block 1, got {request.priorities[1]}"
@pytest.mark.threadleak(enabled=False)
def test_connector_priorities_default(enforce_single_worker,
model_with_connector):
"""Test that priorities are None when no retention config is provided."""
model_fn, scheduler, worker = model_with_connector
model = model_fn(disable_overlap_scheduler=True)
scheduler.get_num_new_matched_tokens.return_value = 0, False
worker.get_finished.return_value = [], []
sampling_params = SamplingParams(max_tokens=4, ignore_eos=True)
# Generate without retention config
generate_and_sleep(model, [0] * 48, sampling_params=sampling_params)
first_call = scheduler.build_connector_meta.call_args_list[0]
sched_output = first_call.args[0]
assert len(sched_output.new_requests) == 1
request = sched_output.new_requests[0]
# Without retention config, priorities should be None
assert request.priorities is None