From a6183563b6f604ef7b481ce8ce7af359c6dc1b74 Mon Sep 17 00:00:00 2001 From: Wei Zhao <51183510+wzhao18@users.noreply.github.com> Date: Thu, 4 Jun 2026 03:48:31 -0400 Subject: [PATCH] [Prefix Caching] DeepSeekv4 - Support selective prefix-cache retention for sliding-window KV cache (#43447) Signed-off-by: wzhao18 Signed-off-by: Yifan Qiao Co-authored-by: Yifan Qiao --- tests/v1/core/test_kv_cache_utils.py | 37 ++ tests/v1/core/test_prefix_caching.py | 557 +++++++++++++++++++ vllm/envs.py | 12 + vllm/v1/core/block_pool.py | 16 +- vllm/v1/core/kv_cache_coordinator.py | 54 +- vllm/v1/core/kv_cache_utils.py | 21 + vllm/v1/core/single_type_kv_cache_manager.py | 140 +++-- 7 files changed, 792 insertions(+), 45 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 68ad7bc42ef..c2eb576d895 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n(): ) +def test_free_kv_cache_block_queue_prepend_n(): + # Seed the queue with one block so prepend has an existing head to splice + # in front of (fake_head->b0->fake_tail). + blocks = [KVCacheBlock(block_id=i) for i in range(6)] + queue = FreeKVCacheBlockQueue(blocks[0:1]) + + # Prepend 0 blocks is a no-op. + queue.prepend_n([]) + assert queue.num_free_blocks == 1 + assert queue.fake_free_list_head.next_free_block is blocks[0] + + # Prepend 2 blocks; they land in front of the existing head, in order. 
+ # fake_head->b4->b5->b0->fake_tail + queue.prepend_n(blocks[4:6]) + assert queue.num_free_blocks == 3 + assert queue.fake_free_list_head.next_free_block is blocks[4] + assert blocks[4].prev_free_block is queue.fake_free_list_head + assert blocks[4].next_free_block is blocks[5] + assert blocks[5].prev_free_block is blocks[4] + assert blocks[5].next_free_block is blocks[0] + assert blocks[0].prev_free_block is blocks[5] + assert blocks[0].next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is blocks[0] + + # A second prepend goes ahead of everything previously prepended. + # fake_head->b1->b2->b4->b5->b0->fake_tail + queue.prepend_n(blocks[1:3]) + assert queue.num_free_blocks == 5 + assert queue.fake_free_list_head.next_free_block is blocks[1] + assert blocks[1].next_free_block is blocks[2] + assert blocks[2].next_free_block is blocks[4] + + # The popleft order reflects the front-to-back queue order. + assert [queue.popleft().block_id for _ in range(5)] == [1, 2, 4, 5, 0] + assert queue.num_free_blocks == 0 + + def test_free_kv_cache_block_queue_popleft_n(): blocks = [KVCacheBlock(block_id=i) for i in range(6)] # Create an empty FreeKVCacheBlockQueue with these blocks diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 91c5f37b417..f682940756a 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import ( KVCacheGroupSpec, KVCacheSpecKind, MambaSpec, + MLAAttentionSpec, SlidingWindowSpec, ) @@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm(): ) +def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch): + """Verify fixed intervals retain sparse tails plus the latest replay tail.""" + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64") + block_size = 8 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ 
+ KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ), + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + # The SWA manager uses the configured 64-token interval (a multiple of the + # 32-token lcm_block_size) as its retention segment. For this 128-token + # prompt, the retained SWA tails are the 64-token interval boundary, the + # 96-token replay boundary, and the 128-token interval boundary. + token_ids = [i for i in range(16) for _ in range(block_size)] + req = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req) + blocks = manager.allocate_slots( + req, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + expected_swa_cached = {7, 11, 15} + for i in range(16): + cached = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1]) + if i in expected_swa_cached: + assert cached is not None, f"SWA hash {i} should be cached" + else: + assert cached is None, f"SWA hash {i} should not be cached" + + +@pytest.mark.parametrize( + "interval, expected_match", + [ + # scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it. + ("33", "multiple of scheduler_block_size"), + # A negative multiple (-32 % 32 == 0) must still be rejected explicitly, + # otherwise it would pass the modulo check and silently degrade to dense. 
+ ("-32", "non-negative"), + ], +) +def test_hybrid_local_kv_retention_interval_rejects_invalid( + monkeypatch, interval, expected_match +): + """A retention interval that is negative or not a multiple of + scheduler_block_size errors out at construction time.""" + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval) + block_size = 8 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ), + ), + ], + ) + with pytest.raises(ValueError, match=expected_match): + make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + +def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch): + """Verify retained local checkpoints are reused after block recycling.""" + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024") + hash_block_size = 4 + kv_cache_config = KVCacheConfig( + num_blocks=800, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["full"], + MLAAttentionSpec( + block_size=64 * hash_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + compress_ratio=4, + ), + ), + KVCacheGroupSpec( + ["swa"], + SlidingWindowSpec( + block_size=16 * hash_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=512, + ), + ), + KVCacheGroupSpec( + ["c128"], + SlidingWindowSpec( + block_size=2 * hash_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=128, + ), + ), + KVCacheGroupSpec( + ["c4"], + SlidingWindowSpec( + block_size=hash_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=8, 
+ ), + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=4096, + enable_caching=True, + hash_block_size=hash_block_size, + ) + + def fill_request(request_id: str, token_offset: int) -> list[int]: + token_ids = [ + token_offset + i for i in range(1024) for _ in range(hash_block_size) + ] + fill_req = make_request(request_id, token_ids, hash_block_size, sha256) + while fill_req.num_computed_tokens < len(token_ids): + num_new_tokens = min(512, len(token_ids) - fill_req.num_computed_tokens) + blocks = manager.allocate_slots(fill_req, num_new_tokens) + assert blocks is not None + fill_req.num_computed_tokens += num_new_tokens + manager.free(fill_req) + return token_ids + + token_ids = fill_request("fill_0", 0) + replay_req = make_request("replay", token_ids[:1800], hash_block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req) + assert num_computed_tokens == 1024 + assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256] + + fill_request("fill_1", 100_000) + replay_req = make_request("replay_again", token_ids[:1800], hash_block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req) + assert num_computed_tokens == 1024 + assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256] + + +def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch): + """Verify latest-only retention reuses only the replayable prompt boundary.""" + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0") + block_size = 8 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, 
+ sliding_window=block_size, + ), + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + token_ids = [i for i in range(16) for _ in range(block_size)] + req0 = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req0) + blocks = manager.allocate_slots( + req0, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + expected_swa_cached = {11} + for i in range(16): + cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1]) + if i in expected_swa_cached: + assert cached is not None, f"SWA hash {i} should be cached" + else: + assert cached is None, f"SWA hash {i} should not be cached" + + manager.free(req0) + retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1]) + assert retained_swa_block is not None + assert retained_swa_block[0].ref_cnt == 0 + + req1 = make_request("1", token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + # Full prompt hits intentionally recompute the final block for logits, so + # the longest usable hit is the previous LCM boundary: 96 tokens. + assert num_computed_tokens == 12 * block_size + assert len(computed_blocks.blocks[1]) == 12 + + shorter_req = make_request("2", token_ids[: 12 * block_size], block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req) + assert num_computed_tokens == 0 + assert len(computed_blocks.blocks[1]) == 0 + + +def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch): + """Verify MTP/EAGLE SWA retention keeps the extra proof block. + + EAGLE/MTP lookup matches one additional local block after the returned + prefix and then drops it. 
Sparse retention must therefore cache the normal + local tail at the latest replay boundary plus one extra SWA block. + """ + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0") + block_size = 8 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["full"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["swa_mtp"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ), + is_eagle_group=True, + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + use_eagle=True, + ) + + # 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96. + # The EAGLE/MTP SWA lookup group must cache the local tail ending at + # 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12. 
+ token_ids = [i for i in range(15) for _ in range(block_size)] + [15] * 7 + req0 = make_request("0", token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, + len(token_ids), + num_computed_tokens, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + expected_swa_cached = {11, 12} + for i in range(15): + cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1]) + if i in expected_swa_cached: + assert cached is not None, f"SWA hash {i} should be cached" + else: + assert cached is None, f"SWA hash {i} should not be cached" + + manager.free(req0) + + req1 = make_request("1", token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert num_computed_tokens == 12 * block_size + assert [len(blocks) for blocks in computed_blocks.blocks] == [3, 12] + + def test_block_lookup_cache_single_block_per_key(): cache = BlockHashToBlockMap() key0 = BlockHashWithGroupId(b"hash0") @@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized(): req = make_request("oversized", list(range(prompt_len)), block_size, sha256) assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None + + +def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch): + """Default path (no retention): freeing an SWA request must place its + uncached scratch blocks at the front of the free queue (recycled first) + and keep its cached checkpoint blocks at the back (retained for prefix + hits). 
This split is always-on, independent of the retention interval.""" + monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False) + block_size = 8 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ), + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + token_ids = [i for i in range(16) for _ in range(block_size)] + req = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req) + blocks = manager.allocate_slots( + req, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + swa_manager = manager.coordinator.single_type_managers[1] + null_block = manager.block_pool.null_block + cached_ids: set[int] = set() + uncached_ids: set[int] = set() + cached_hash_indices: list[int] = [] + for i, block in enumerate(swa_manager.req_to_blocks[req.request_id]): + if block is null_block: + continue + if block.block_hash is None: + uncached_ids.add(block.block_id) + else: + cached_ids.add(block.block_id) + cached_hash_indices.append(i) + # The dense default mask caches only the per-segment tails, so a 16-block + # SWA prompt must produce a mix of retained and scratch blocks. 
+ assert cached_ids, "expected some retained (cached) SWA tail blocks" + assert uncached_ids, "expected some scratch (uncached) SWA blocks" + + manager.free(req) + + order = [ + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() + ] + pos = {bid: i for i, bid in enumerate(order)} + # Every scratch block is recycled before every retained block. + assert max(pos[bid] for bid in uncached_ids) < min(pos[bid] for bid in cached_ids) + # The retained tails survive the free and still serve a prefix-cache hit. + for i in cached_hash_indices: + assert ( + manager.block_pool.get_cached_block( + req.block_hashes[i], kv_cache_group_ids=[1] + ) + is not None + ) + + +def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs): + """Single sliding-window group (UnitaryKVCacheCoordinator).""" + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=sliding_window, + ), + ), + ], + ) + return make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + **kwargs, + ) + + +def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch): + """Sparse retention must work for a pure-SWA single-group model, not just + hybrid models: only the per-interval tails plus the latest replay tail are + cached, and a replay still hits the latest replayable boundary.""" + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64") + block_size = 16 + manager = _make_pure_swa_manager(block_size, sliding_window=block_size) + assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator" + + token_ids = [i for i in range(16) for _ in range(block_size)] + req = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req) + 
blocks = manager.allocate_slots( + req, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + cached = { + i + for i in range(16) + if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0]) + is not None + } + # per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at + # i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail + # block 14. Crucially this is a strict subset of all 16 blocks: retention + # is actually sparse for pure SWA (not silently dense). + assert cached == {3, 7, 11, 14, 15} + + # A replay of the same prompt hits the latest replayable boundary (240). + replay = make_request("1", token_ids, block_size, sha256) + _, num_computed = manager.get_computed_blocks(replay) + assert num_computed == 240 + + +def test_pure_swa_retention_latest_only(monkeypatch): + """`=0` on a pure-SWA model keeps only the latest replay tail.""" + monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0") + block_size = 16 + manager = _make_pure_swa_manager(block_size, sliding_window=block_size) + + token_ids = [i for i in range(16) for _ in range(block_size)] + req = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req) + blocks = manager.allocate_slots( + req, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + cached = { + i + for i in range(16) + if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0]) + is not None + } + # No segment tails (interval 0); only the latest replay tail (block 14). 
+ assert cached == {14} + + replay = make_request("1", token_ids, block_size, sha256) + _, num_computed = manager.get_computed_blocks(replay) + assert num_computed == 240 + + +def test_pure_swa_retention_dense_default_caches_all(monkeypatch): + """With retention unset, a pure-SWA model must keep the dense behavior: + every block boundary is a potential hit, so all blocks are cached.""" + monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False) + block_size = 16 + manager = _make_pure_swa_manager(block_size, sliding_window=block_size) + + token_ids = [i for i in range(16) for _ in range(block_size)] + req = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req) + blocks = manager.allocate_slots( + req, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + cached = { + i + for i in range(16) + if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0]) + is not None + } + assert cached == set(range(16)) diff --git a/vllm/envs.py b/vllm/envs.py index dc11fbd224d..bb3bb34284b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -279,6 +279,7 @@ if TYPE_CHECKING: VLLM_LORA_ENABLE_DUAL_STREAM: bool = False VLLM_GPU_NIC_PCIE_MAPPING: str = "" VLLM_NIC_SELECTION_VARS: str = "" + VLLM_PREFIX_CACHE_RETENTION_INTERVAL: int | None = None def get_default_cache_root(): @@ -1032,6 +1033,17 @@ environment_variables: dict[str, Callable[[], Any]] = { if "VLLM_PLUGINS" not in os.environ else os.environ["VLLM_PLUGINS"].split(",") ), + # Retain local sliding-window KV checkpoints for prefix caching. + # Unset (default) preserves the dense local checkpointing behavior. `0` + # retains only the latest completed prompt boundary. Positive values retain + # checkpoints at the specified interval boundaries (must be a multiple of + # scheduler_block_size, the prefix-cache alignment; otherwise rejected). 
+ # Sliding-window attention only for now; Mamba/linear attention is not yet covered. + "VLLM_PREFIX_CACHE_RETENTION_INTERVAL": lambda: ( + int(os.environ["VLLM_PREFIX_CACHE_RETENTION_INTERVAL"]) + if "VLLM_PREFIX_CACHE_RETENTION_INTERVAL" in os.environ + else None + ), # a local directory to look in for unrecognized LoRA adapters. # only works if plugins are enabled and # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 513e4bf380b..4202f527082 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -416,21 +416,29 @@ class BlockPool: if self.metrics_collector: self.metrics_collector.on_block_accessed(block) - def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + def free_blocks( + self, ordered_blocks: Iterable[KVCacheBlock], prepend: bool = False + ) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. Args: ordered_blocks: A list of blocks to free ordered by their eviction priority. + prepend: Whether to put newly-free blocks at the front of the free + queue to be prioritized for reuse. """ # Materialize the iterable to allow multiple passes. blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - self.free_block_queue.append_n( - [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] - ) + freed_blocks = [ + block for block in blocks_list if block.ref_cnt == 0 and not block.is_null + ] + if prepend: + self.free_block_queue.prepend_n(freed_blocks) + else: + self.free_block_queue.append_n(freed_blocks) def evict_blocks(self, block_ids: set[int]) -> None: """evict blocks from the prefix cache by their block IDs. 
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 387f1a1e335..89b1e84a44e 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import NamedTuple +from vllm import envs from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( @@ -21,10 +22,41 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, KVCacheSpec, + SlidingWindowSpec, ) from vllm.v1.request import Request +def _validate_prefix_cache_retention_interval( + retention_interval: int | None, + scheduler_block_size: int, + kv_cache_config: KVCacheConfig, +) -> None: + if retention_interval is None: + return + + # Retention only sparsifies sliding-window checkpoints for now; every other + # manager (full attention, Mamba, chunked-local) caches densely and + # ignores it to be conservative. + # TODO: Support Mamba/linear attention. + if not any( + isinstance(g.kv_cache_spec, SlidingWindowSpec) + for g in kv_cache_config.kv_cache_groups + ): + raise ValueError( + "VLLM_PREFIX_CACHE_RETENTION_INTERVAL is set but this model has " + "no sliding-window KV cache group, so retention has no effect. " + "Unset it (the feature only applies to sliding-window attention)." + ) + + if retention_interval < 0 or retention_interval % scheduler_block_size != 0: + raise ValueError( + f"VLLM_PREFIX_CACHE_RETENTION_INTERVAL ({retention_interval}) " + "must be non-negative and a multiple of scheduler_block_size " + f"({scheduler_block_size})." + ) + + class KVCacheCoordinator(ABC): """ Coordinate the KV cache of different KV cache groups. 
@@ -86,6 +118,14 @@ class KVCacheCoordinator(ABC): for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) + # A positive retention interval must be a multiple of the base hit granularity + # (``scheduler_block_size``) to land on real cache-hit boundaries. + # 0 = keep only the latest replay boundary; None = dense (the default). + self.retention_interval = envs.VLLM_PREFIX_CACHE_RETENTION_INTERVAL + _validate_prefix_cache_retention_interval( + self.retention_interval, self.scheduler_block_size, kv_cache_config + ) + def get_num_blocks_to_allocate( self, request_id: str, @@ -215,7 +255,11 @@ class KVCacheCoordinator(ABC): (including tokens that are already cached). """ for manager in self.single_type_managers: - manager.cache_blocks(request, num_computed_tokens) + manager.cache_blocks( + request, + num_computed_tokens, + retention_interval=self.retention_interval, + ) def free(self, request_id: str) -> None: """ @@ -525,8 +569,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): num_computed_tokens, aligned_num_computed_tokens + manager.block_size, ) + # The manager already knows the fine hit granularity + # (``scheduler_block_size``); retention is passed separately so it + # can keep both the coarse segment tails and the fine replay + # boundary (which needs the fine value). 
manager.cache_blocks( - request, num_tokens_to_cache, alignment_tokens=self.scheduler_block_size + request, + num_tokens_to_cache, + retention_interval=self.retention_interval, ) def find_longest_cache_hit( diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index cfa79f077a1..ae3db581c0d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -327,6 +327,27 @@ class FreeKVCacheBlockQueue: self.num_free_blocks += 1 + def prepend_n(self, blocks: list[KVCacheBlock]) -> None: + """Put a list of blocks at the front of the free list.""" + if len(blocks) == 0: + return + + first_block = self.fake_free_list_head.next_free_block + assert first_block is not None, ( + "next_free_block of fake_free_list_head should always exist" + ) + + prev_block = self.fake_free_list_head + for block in blocks: + block.prev_free_block = prev_block + prev_block.next_free_block = block + prev_block = block + + prev_block.next_free_block = first_block + first_block.prev_free_block = prev_block + + self.num_free_blocks += len(blocks) + def append_n(self, blocks: list[KVCacheBlock]) -> None: """Put a list of blocks back into the free list diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 281b79639db..a2f2ea6d96a 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -62,6 +62,7 @@ class SingleTypeKVCacheManager(ABC): block until the request finishes. """ self.scheduler_block_size = scheduler_block_size + # The block size for this manager; used for actual block allocation. self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size self.pcp_world_size = pcp_world_size @@ -298,7 +299,7 @@ class SingleTypeKVCacheManager(ABC): self, request: Request, num_tokens: int, - alignment_tokens: int | None = None, + retention_interval: int | None = None, ) -> None: """ Cache the blocks for the request. 
@@ -307,12 +308,10 @@ class SingleTypeKVCacheManager(ABC): request: The request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). - alignment_tokens: The cache-hit alignment (in tokens) used by the - coordinator's ``find_longest_cache_hit``. When greater than - this group's ``block_size``, managers whose hit logic only - returns a subset of blocks per alignment-aligned segment - (SWA) skip the rest since they can never participate in a - future cache hit. + retention_interval: Sparse local-checkpoint granularity. ``None`` + keeps dense checkpointing; ``0`` keeps only the latest replay + boundary; a positive multiple of ``scheduler_block_size`` keeps + a tail once per that-sized segment. Only SWA acts on it. """ num_cached_blocks = self.num_cached_block.get(request.request_id, 0) num_full_blocks = num_tokens // self.block_size @@ -320,17 +319,15 @@ class SingleTypeKVCacheManager(ABC): if num_cached_blocks >= num_full_blocks: return - # Fast path: when the coordinator imposes no alignment constraint - if alignment_tokens is None or alignment_tokens <= self.block_size: - block_mask = None - else: - block_mask = self.reachable_block_mask( - num_cached_blocks, - num_full_blocks, - alignment_tokens, - self.kv_cache_spec, - self.use_eagle, - ) + block_mask = self.reachable_block_mask( + start_block=num_cached_blocks, + end_block=num_full_blocks, + alignment_tokens=self.scheduler_block_size, + kv_cache_spec=self.kv_cache_spec, + use_eagle=self.use_eagle, + retention_interval=retention_interval, + num_prompt_tokens=request.num_prompt_tokens, + ) self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], @@ -347,10 +344,12 @@ class SingleTypeKVCacheManager(ABC): def reachable_block_mask( cls, start_block: int, - num_blocks: int, - alignment_tokens: int, + end_block: int, + alignment_tokens: int | None, kv_cache_spec: KVCacheSpec, use_eagle: bool, + retention_interval: int | None 
= None, + num_prompt_tokens: int | None = None, ) -> list[bool] | None: """Per-block mask for ``cache_full_blocks``. ``None`` means cache every (non-null) block — the default for full attention. @@ -476,7 +475,12 @@ class SingleTypeKVCacheManager(ABC): # range), so we must cap to the number of blocks that currently exist for # this request. num_skipped_blocks = min(num_skipped_blocks, len(blocks)) - removed_blocks: list[KVCacheBlock] = [] + + # Reuse skipped local blocks in order: + # scratch blocks: no prefix-cache value, reuse first. + # cached blocks: reusable prefix-cache value, reuse last. + removed_cached_blocks: list[KVCacheBlock] = [] + removed_uncached_blocks: list[KVCacheBlock] = [] # Because the block starts from index 0, the num_skipped_block-th block # corresponds to index num_skipped_blocks - 1. for i in range(num_skipped_blocks - 1, -1, -1): @@ -485,9 +489,16 @@ class SingleTypeKVCacheManager(ABC): # should also have been set to null blocks by the previous calls # to this function. break - removed_blocks.append(blocks[i]) + if blocks[i].block_hash is None: + removed_uncached_blocks.append(blocks[i]) + else: + removed_cached_blocks.append(blocks[i]) blocks[i] = self._null_block - self.block_pool.free_blocks(removed_blocks) + # `prepend=True` makes uncached scratch blocks the next allocation + # candidates, while cached blocks stay behind them as best-effort + # prefix-cache entries. 
+ self.block_pool.free_blocks(removed_cached_blocks) + self.block_pool.free_blocks(removed_uncached_blocks, prepend=True) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ @@ -677,30 +688,81 @@ class SlidingWindowManager(SingleTypeKVCacheManager): def reachable_block_mask( cls, start_block: int, - num_blocks: int, - alignment_tokens: int, + end_block: int, + alignment_tokens: int | None, kv_cache_spec: KVCacheSpec, use_eagle: bool, + retention_interval: int | None = None, + num_prompt_tokens: int | None = None, ) -> list[bool] | None: - assert alignment_tokens > kv_cache_spec.block_size assert isinstance(kv_cache_spec, SlidingWindowSpec) - per_segment = alignment_tokens // kv_cache_spec.block_size + if alignment_tokens is None: + # Fast path: when the coordinator imposes no alignment constraint. + return None + assert alignment_tokens % kv_cache_spec.block_size == 0 + + block_size = kv_cache_spec.block_size + # Contiguous blocks a hit needs at a boundary (incl. the EAGLE peek). need = cls._contiguous_blocks_for_hit( window_size=kv_cache_spec.sliding_window, - block_size=kv_cache_spec.block_size, + block_size=block_size, use_eagle=use_eagle, ) - if need >= per_segment: - return None # The matched run's right edge sits on the aligned boundary block when # EAGLE peeks one block past it (shift=1), otherwise on the last block - # before the boundary (shift=0). A block is reachable iff it falls in - # the ``need``-wide run ending at some boundary's right edge. + # before the boundary (shift=0). shift = 1 if use_eagle else 0 - return [ - i >= shift and (i - shift) % per_segment >= per_segment - need - for i in range(start_block, num_blocks) - ] + + mask = [False] * (end_block - start_block) + + # (1) Segment-boundary tails. ``retention_interval``: + # None -> dense (a tail at every ``alignment_tokens`` boundary); + # 0 -> no dense tails (only the replay boundary below); + # >0 -> a tail once per ``retention_interval``-sized segment. 
+ segment_tokens = ( + alignment_tokens + if retention_interval is None + else (None if retention_interval == 0 else retention_interval) + ) + if segment_tokens is not None: + per_segment = segment_tokens // block_size + if need >= per_segment: + # Every block is reachable; cache them all. + return None + for i in range(start_block, end_block): + if i >= shift and (i - shift) % per_segment >= per_segment - need: + mask[i - start_block] = True + + # (2) Replay-boundary tail. ``get_computed_blocks`` caps hits at + # ``num_prompt - 1`` (to recompute the last token's logits), so an exact + # prompt replay can only land on the latest *fine*-aligned boundary. + # Sparse retention would otherwise skip it, so keep its tail explicitly. + if retention_interval is not None and num_prompt_tokens is not None: + latest = (num_prompt_tokens - 1) // alignment_tokens * alignment_tokens + prompt_end_block = latest // block_size + shift + for i in range( + max(start_block, prompt_end_block - need), + min(end_block, prompt_end_block), + ): + mask[i - start_block] = True + + return mask + + def free(self, request_id: str) -> None: + # similar to remove_skipped_blocks(), prepend the uncached blocks + # and append the cached blocks to the free queue + req_blocks = self.req_to_blocks.pop(request_id, []) + if req_blocks: + cached_blocks: list[KVCacheBlock] = [] + uncached_blocks: list[KVCacheBlock] = [] + for block in reversed(req_blocks): + if block.block_hash is None: + uncached_blocks.append(block) + else: + cached_blocks.append(block) + self.block_pool.free_blocks(cached_blocks) + self.block_pool.free_blocks(uncached_blocks, prepend=True) + self.num_cached_block.pop(request_id, None) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ @@ -1152,10 +1214,10 @@ class MambaManager(SingleTypeKVCacheManager): self, request: Request, num_tokens: int, - alignment_tokens: int | None = None, + retention_interval: int | None = None, ) -> None: num_cached_blocks_before = 
self.num_cached_block.get(request.request_id, 0) - super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) + super().cache_blocks(request, num_tokens, retention_interval=retention_interval) num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0) if num_cached_blocks_after > num_cached_blocks_before: for block in self.req_to_blocks[request.request_id][ @@ -1188,7 +1250,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): self, request: Request, num_tokens: int, - alignment_tokens: int | None = None, + retention_interval: int | None = None, ) -> None: # We do not cache blocks for cross-attention to be shared between # requests, so this method is not relevant.