[None][feat] Add priority-based KV cache offload filtering support (#10751)
Signed-off-by: Yuewei Na <yna@nvidia.com>
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
This commit is contained in:
parent 9601b17459
commit 0d18b2d7a4
@@ -1687,6 +1687,14 @@ public:
         = 0;
 
     virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
+
+    //! @brief Get the retention priority of a block by its ID.
+    //! @param blockId The ID of the block.
+    //! @param windowSize The attention window size this block belongs to.
+    //! @return The retention priority of the block, or default priority if block not found.
+    [[nodiscard]] virtual executor::RetentionPriority getPriorityByBlockId(
+        KVCacheBlock::IdType blockId, SizeType32 windowSize) const
+        = 0;
 };
 
 class KVCacheManager : public BaseKVCacheManager
@@ -1970,6 +1978,9 @@ public:
 
     void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
 
+    [[nodiscard]] executor::RetentionPriority getPriorityByBlockId(
+        KVCacheBlock::IdType blockId, SizeType32 windowSize) const override;
+
     std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
 
     /// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
@@ -2605,6 +2605,26 @@ void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
     mBlockManager.unpinBlocksById(blockIds);
 }
 
+tle::RetentionPriority KVCacheManager::getPriorityByBlockId(KVCacheBlock::IdType blockId, SizeType32 windowSize) const
+{
+    try
+    {
+        BlockPtr const& block = mBlockManager.getBlockById(blockId, windowSize);
+        if (block)
+        {
+            return block->getPriority();
+        }
+        TLLM_LOG_WARNING("getPriorityByBlockId: Block ID %d not found in window %d", blockId, windowSize);
+        return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
+    }
+    catch (std::out_of_range const& ex)
+    {
+        TLLM_LOG_WARNING(
+            "getPriorityByBlockId: Block ID %d or window size %d out of range: %s", blockId, windowSize, ex.what());
+        return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
+    }
+}
+
 SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
 {
     auto const& sequence = getSequence(requestId);
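Note that the implementation is deliberately total: an unknown block ID or a bad window size logs a warning and falls back to kDefaultRetentionPriority instead of throwing. A minimal Python-side sketch of relying on that contract (the manager handle, helper name, and block IDs here are hypothetical, not code from this commit):

    # Hypothetical sketch: because get_priority_by_block_id never raises for
    # bad inputs, callers can map arbitrary block IDs without try/except.
    def priorities_for(manager, block_ids, window_size):
        # Unknown IDs simply come back as the default retention priority.
        return [manager.get_priority_by_block_id(b, window_size) for b in block_ids]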
@@ -198,6 +198,7 @@ void initBindings(nb::module_& m)
         .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
         .def_prop_ro("is_child", &GenLlmReq::isChild)
         .def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID)
+        .def_prop_ro("kv_cache_retention_config", &GenLlmReq::getKvCacheRetentionConfig)
         .def_prop_ro("multimodal_hashes",
             [](GenLlmReq& self)
             {
@@ -494,7 +494,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
-        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
+        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>())
+        .def("get_priority_by_block_id", &BaseKVCacheManager::getPriorityByBlockId, nb::arg("block_id"),
+            nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>());
 
     nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
         .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })
@@ -4094,6 +4094,61 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority)
     }
 }
 
+TEST_F(KVCacheManagerTest, GetPriorityByBlockId)
+{
+    auto constexpr numLayers = 2;
+    auto constexpr numKvHeads = 2;
+    auto constexpr sizePerHead = 16;
+    auto constexpr tokensPerBlock = 4;
+    auto constexpr numBlocks = 8;
+    auto constexpr maxAttentionWindow = 32;
+    auto constexpr maxNumSequences = 4;
+    auto constexpr beamWidth = 1;
+    auto constexpr dtype = nvinfer1::DataType::kHALF;
+    auto const stream = std::make_shared<tr::CudaStream>();
+    SizeType32 constexpr maxNewTokens = 4;
+    tr::SamplingConfig const samplingConfig{beamWidth};
+    bool constexpr isStreaming{false};
+    tle::RetentionPriority constexpr highPriority = 80;
+
+    auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {numBlocks, 0}}};
+
+    KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
+        beamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, dtype, 0, stream,
+        maxAttentionWindow, true);
+    kvCacheManager.allocatePools(false);
+
+    // Create a sequence and set a custom priority
+    auto inputTokens = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7});
+    auto const inputLength = static_cast<SizeType32>(inputTokens->size());
+    auto llmRequest = std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens, samplingConfig, isStreaming);
+
+    // Set high priority for context blocks
+    llmRequest->setKvCacheRetentionConfig(KvCacheRetentionConfig(
+        {KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, highPriority)}, highPriority));
+
+    kvCacheManager.addSequence(0, inputLength, beamWidth, llmRequest);
+    kvCacheManager.storeContextBlocks(*llmRequest);
+
+    // Get block IDs for the sequence
+    auto const& seq = kvCacheManager.getSequence(0);
+    auto cacheBlockIds = seq.getCacheBlockIds(maxAttentionWindow).at(0);
+    ASSERT_GE(cacheBlockIds.size(), 1);
+
+    // Test 1: Valid block ID should return the set priority
+    auto const validBlockId = cacheBlockIds[0];
+    auto const retrievedPriority = kvCacheManager.getPriorityByBlockId(validBlockId, maxAttentionWindow);
+    EXPECT_EQ(retrievedPriority, highPriority);
+
+    // Test 2: Invalid block ID (negative) should return default priority
+    auto const invalidNegative = kvCacheManager.getPriorityByBlockId(-1, maxAttentionWindow);
+    EXPECT_EQ(invalidNegative, KvCacheRetentionConfig::kDefaultRetentionPriority);
+
+    // Test 3: Invalid block ID (out of range) should return default priority
+    auto const invalidOutOfRange = kvCacheManager.getPriorityByBlockId(9999, maxAttentionWindow);
+    EXPECT_EQ(invalidOutOfRange, KvCacheRetentionConfig::kDefaultRetentionPriority);
+}
+
 TEST(KVCacheManagerHelpersTest, ChopVectorIntoBlocksBasicNoPartial)
 {
     using namespace tensorrt_llm::batch_manager::kv_cache_manager;
@@ -69,6 +69,9 @@ class RequestData:
     computed_position: int
     # The number of scheduled tokens for the upcoming forward pass.
     num_scheduled_tokens: int
+    # The retention priorities for each new block (same length as new_block_ids).
+    # Used for priority-based offload filtering. None means use default priority.
+    priorities: Optional[List[int]] = None
 
 
 # A class to store some basic data regarding all inflight requests.
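For illustration, a standalone mirror of the extended dataclass, using only the fields visible in this diff (the real class may carry more), with a populated example instance:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class RequestDataSketch:  # hypothetical stand-in for RequestData
        request_id: int
        new_tokens: List[int]
        new_block_ids: List[int]
        computed_position: int
        num_scheduled_tokens: int
        # One priority per entry in new_block_ids; None means default priority.
        priorities: Optional[List[int]] = None

    # Two new blocks: a high-priority system-prompt block and a low-priority one.
    rd = RequestDataSketch(0, [7, 8], [12, 13], 64, 1, priorities=[80, 10])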
@@ -314,8 +317,17 @@ class KvCacheConnectorSchedulerOutputRequest:
         num_scheduled_tokens = 1 + get_draft_token_length(
             req)  # Specdec with draft tokens is not supported yet.
 
+        # Get retention priority for each new block only if retention config is provided
+        # (for priority-based offload filtering)
+        priorities = None
+        if req.kv_cache_retention_config is not None:
+            priorities = [
+                kv_cache_manager.get_priority_by_block_id(block_id)
+                for block_id in new_block_ids
+            ]
+
         return RequestData(req.request_id, new_tokens, new_block_ids,
-                           computed_position, num_scheduled_tokens)
+                           computed_position, num_scheduled_tokens, priorities)
 
 
 class KvCacheConnectorSchedulerOutputManager:
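The priorities carried here are what enable the priority-based offload filtering named in the commit title. A hypothetical connector-side filter might look like the following (the function name, threshold, and the default value of 35 are illustrative assumptions, not part of this commit):

    DEFAULT_PRIORITY = 35  # assumed default, per the KVCacheManager docstring below

    def blocks_to_offload(request_data, min_priority=50):
        # Keep only blocks whose retention priority meets the threshold.
        priorities = request_data.priorities
        if priorities is None:
            priorities = [DEFAULT_PRIORITY] * len(request_data.new_block_ids)
        return [
            block_id
            for block_id, priority in zip(request_data.new_block_ids, priorities)
            if priority >= min_priority
        ]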
@@ -560,6 +560,14 @@ def create_py_executor(
             raise NotImplementedError(
                 "KV connector is only supported with guaranteed no evict scheduler policy."
             )
+
+        max_attention_window = kv_cache_config.max_attention_window
+        if max_attention_window is not None and len(
+                set(max_attention_window)) > 1:
+            raise NotImplementedError(
+                "KV connector is not supported with VSWA (Variable Sliding Window Attention)."
+            )
+
         try:
             module = importlib.import_module(
                 kv_connector_config.connector_module)
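The guard rejects only genuinely variable windows: a max_attention_window list whose entries are all equal still passes, since set() collapses it to a single size. A standalone sketch of the same check (the function name is illustrative):

    # Simplified stand-in for the new check in create_py_executor.
    def check_connector_supports_windows(max_attention_window):
        if max_attention_window is not None and len(set(max_attention_window)) > 1:
            raise NotImplementedError(
                "KV connector is not supported with VSWA (Variable Sliding Window Attention)."
            )

    check_connector_supports_windows(None)          # ok: no explicit windows
    check_connector_supports_windows([4096, 4096])  # ok: uniform window sizes
    # check_connector_supports_windows([1024, 4096])  # raises NotImplementedError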
@@ -901,6 +901,25 @@ class KVCacheManager(BaseResourceManager):
     def get_last_block_id(self, request_id: int) -> int:
         return self.impl.get_last_block_id(request_id)
 
+    def get_priority_by_block_id(self,
+                                 block_id: int,
+                                 window_size: Optional[int] = None) -> int:
+        """Get the retention priority of a block by its ID.
+
+        Args:
+            block_id: The ID of the block.
+            window_size: The attention window size this block belongs to.
+                Required for VSWA configurations with multiple window sizes.
+
+        Returns:
+            The retention priority of the block (0-100), or default priority (35) if not found.
+        """
+        if window_size is None:
+            if len(self.max_attention_window_vec) > 1:
+                raise ValueError("window_size must be provided for VSWA")
+            window_size = self.max_attention_window_vec[0]
+        return self.impl.get_priority_by_block_id(block_id, window_size)
+
     def get_batch_cache_indices(
         self,
         request_ids: List[int],
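Usage follows directly from the signature: with a single configured attention window, window_size can be omitted and is inferred; with multiple windows (VSWA) it must be passed explicitly. A short sketch (kv_cache_manager is assumed to be an initialized KVCacheManager, and the block ID and window size are hypothetical):

    # Single-window setup: the window size is inferred from the manager.
    priority = kv_cache_manager.get_priority_by_block_id(block_id=12)

    # VSWA setup with several window sizes: window_size is required; omitting
    # it raises ValueError("window_size must be provided for VSWA").
    priority = kv_cache_manager.get_priority_by_block_id(block_id=12, window_size=4096)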
@@ -22,6 +22,7 @@ import pytest
 from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
 from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig, KvCacheConfig,
                                           KvCacheConnectorConfig)
+from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig
 
 from ..conftest import llm_models_root
 
@@ -439,3 +440,100 @@ def test_connector_multi_request(enforce_single_worker, model_with_connector):
 
     # The KV cache of both prior requests should be freed, allowing the third request to run.
     model.generate([2] * 110, sampling_params=sampling_params)
+
+
+@pytest.mark.threadleak(enabled=False)
+def test_connector_priorities(enforce_single_worker, model_with_connector):
+    """Test that retention priorities flow through the connector correctly.
+
+    This test verifies that when KvCacheRetentionConfig is provided,
+    the RequestData.priorities field is populated with the correct
+    per-block priorities based on the token ranges.
+    """
+    BLOCK_SIZE = 32
+    NUM_INPUT_TOKENS = 64  # 2 blocks
+    NUM_TOKENS = 4
+    HIGH_PRIORITY = 80  # For system prompt blocks
+    LOW_PRIORITY = 10  # For user input / decode blocks
+
+    model_fn, scheduler, worker = model_with_connector
+
+    model = model_fn(disable_overlap_scheduler=True)
+
+    scheduler.get_num_new_matched_tokens.return_value = 0, False
+    worker.get_finished.return_value = [], []
+
+    # Create retention config with different priorities for different token ranges:
+    # - First 32 tokens (block 0): high priority (e.g., system prompt)
+    # - Remaining tokens (block 1+): low priority (e.g., user input)
+    retention_config = KvCacheRetentionConfig(
+        token_range_retention_priorities=[
+            KvCacheRetentionConfig.TokenRangeRetentionConfig(
+                token_start=0,
+                token_end=32,
+                priority=HIGH_PRIORITY,
+            ),
+            KvCacheRetentionConfig.TokenRangeRetentionConfig(
+                token_start=32,
+                token_end=None,  # Extend to end of sequence
+                priority=LOW_PRIORITY,
+            ),
+        ],
+        decode_retention_priority=LOW_PRIORITY,
+    )
+
+    sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)
+
+    generate_and_sleep(model, [0] * NUM_INPUT_TOKENS,
+                       sampling_params=sampling_params,
+                       kv_cache_retention_config=retention_config)
+
+    # Verify that build_connector_meta was called
+    assert scheduler.build_connector_meta.call_count >= 1
+
+    # Check the first call (new request) has priorities set
+    first_call = scheduler.build_connector_meta.call_args_list[0]
+    sched_output = first_call.args[0]
+
+    assert len(sched_output.new_requests) == 1
+    request = sched_output.new_requests[0]
+
+    # Should have 2 blocks for 64 input tokens with block size 32
+    expected_num_blocks = math.ceil(NUM_INPUT_TOKENS / BLOCK_SIZE)
+    assert len(request.new_block_ids) == expected_num_blocks
+
+    # Priorities should be set and match the retention config
+    assert request.priorities is not None
+    assert len(request.priorities) == len(request.new_block_ids)
+
+    # First block should have high priority, second block should have low priority
+    assert request.priorities[
+        0] == HIGH_PRIORITY, f"Expected priority {HIGH_PRIORITY} for block 0, got {request.priorities[0]}"
+    assert request.priorities[
+        1] == LOW_PRIORITY, f"Expected priority {LOW_PRIORITY} for block 1, got {request.priorities[1]}"
+
+
+@pytest.mark.threadleak(enabled=False)
+def test_connector_priorities_default(enforce_single_worker,
+                                      model_with_connector):
+    """Test that priorities are None when no retention config is provided."""
+    model_fn, scheduler, worker = model_with_connector
+
+    model = model_fn(disable_overlap_scheduler=True)
+
+    scheduler.get_num_new_matched_tokens.return_value = 0, False
+    worker.get_finished.return_value = [], []
+
+    sampling_params = SamplingParams(max_tokens=4, ignore_eos=True)
+
+    # Generate without retention config
+    generate_and_sleep(model, [0] * 48, sampling_params=sampling_params)
+
+    first_call = scheduler.build_connector_meta.call_args_list[0]
+    sched_output = first_call.args[0]
+
+    assert len(sched_output.new_requests) == 1
+    request = sched_output.new_requests[0]
+
+    # Without retention config, priorities should be None
+    assert request.priorities is None
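The per-block expectations in test_connector_priorities follow from mapping token ranges onto block boundaries: tokens 0-31 land in block 0 (priority 80) and tokens 32-63 in block 1 (priority 10). One plausible mapping consistent with the test's assertions (illustrative only; the real assignment happens inside the KV cache manager):

    BLOCK_SIZE = 32
    # (token_start, token_end, priority); token_end=None extends to the sequence end.
    RANGES = [(0, 32, 80), (32, None, 10)]

    def block_priority(block_idx, seq_len, default=35):
        # Illustrative rule: a block takes the priority of the range that
        # contains its first token (the ranges here align with block boundaries).
        first_token = block_idx * BLOCK_SIZE
        for start, end, priority in RANGES:
            if start <= first_token < (seq_len if end is None else end):
                return priority
        return default

    print([block_priority(i, seq_len=64) for i in range(2)])  # -> [80, 10]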