[Prefix Caching] DeepSeekv4 - Support selective prefix-cache retention for sliding-window KV cache (#43447)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Co-authored-by: Yifan Qiao <yifanqiao@inferact.ai>
This commit is contained in:
Wei Zhao
2026-06-04 03:48:31 -04:00
committed by GitHub
parent 22c2e87555
commit a6183563b6
7 changed files with 792 additions and 45 deletions
+37
View File
@@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n():
)
def test_free_kv_cache_block_queue_prepend_n():
# Seed the queue with one block so prepend has an existing head to splice
# in front of (fake_head->b0->fake_tail).
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
queue = FreeKVCacheBlockQueue(blocks[0:1])
# Prepend 0 blocks is a no-op.
queue.prepend_n([])
assert queue.num_free_blocks == 1
assert queue.fake_free_list_head.next_free_block is blocks[0]
# Prepend 2 blocks; they land in front of the existing head, in order.
# fake_head->b4->b5->b0->fake_tail
queue.prepend_n(blocks[4:6])
assert queue.num_free_blocks == 3
assert queue.fake_free_list_head.next_free_block is blocks[4]
assert blocks[4].prev_free_block is queue.fake_free_list_head
assert blocks[4].next_free_block is blocks[5]
assert blocks[5].prev_free_block is blocks[4]
assert blocks[5].next_free_block is blocks[0]
assert blocks[0].prev_free_block is blocks[5]
assert blocks[0].next_free_block is queue.fake_free_list_tail
assert queue.fake_free_list_tail.prev_free_block is blocks[0]
# A second prepend goes ahead of everything previously prepended.
# fake_head->b1->b2->b4->b5->b0->fake_tail
queue.prepend_n(blocks[1:3])
assert queue.num_free_blocks == 5
assert queue.fake_free_list_head.next_free_block is blocks[1]
assert blocks[1].next_free_block is blocks[2]
assert blocks[2].next_free_block is blocks[4]
# The popleft order reflects the front-to-back queue order.
assert [queue.popleft().block_id for _ in range(5)] == [1, 2, 4, 5, 0]
assert queue.num_free_blocks == 0
def test_free_kv_cache_block_queue_popleft_n():
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Create an empty FreeKVCacheBlockQueue with these blocks
+557
View File
@@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
KVCacheSpecKind,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
@@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
)
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
"""Verify fixed intervals retain sparse tails plus the latest replay tail."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# The SWA manager uses the configured 64-token interval (a multiple of the
# 32-token lcm_block_size) as its retention segment. For this 128-token
# prompt, the retained SWA tails are the 64-token interval boundary, the
# 96-token replay boundary, and the 128-token interval boundary.
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
expected_swa_cached = {7, 11, 15}
for i in range(16):
cached = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1])
if i in expected_swa_cached:
assert cached is not None, f"SWA hash {i} should be cached"
else:
assert cached is None, f"SWA hash {i} should not be cached"
@pytest.mark.parametrize(
"interval, expected_match",
[
# scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
("33", "multiple of scheduler_block_size"),
# A negative multiple (-32 % 32 == 0) must still be rejected explicitly,
# otherwise it would pass the modulo check and silently degrade to dense.
("-32", "non-negative"),
],
)
def test_hybrid_local_kv_retention_interval_rejects_invalid(
monkeypatch, interval, expected_match
):
"""A retention interval that is negative or not a multiple of
scheduler_block_size errors out at construction time."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
with pytest.raises(ValueError, match=expected_match):
make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
"""Verify retained local checkpoints are reused after block recycling."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
hash_block_size = 4
kv_cache_config = KVCacheConfig(
num_blocks=800,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["full"],
MLAAttentionSpec(
block_size=64 * hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.uint8,
compress_ratio=4,
),
),
KVCacheGroupSpec(
["swa"],
SlidingWindowSpec(
block_size=16 * hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=512,
),
),
KVCacheGroupSpec(
["c128"],
SlidingWindowSpec(
block_size=2 * hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=128,
),
),
KVCacheGroupSpec(
["c4"],
SlidingWindowSpec(
block_size=hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=8,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=4096,
enable_caching=True,
hash_block_size=hash_block_size,
)
def fill_request(request_id: str, token_offset: int) -> list[int]:
token_ids = [
token_offset + i for i in range(1024) for _ in range(hash_block_size)
]
fill_req = make_request(request_id, token_ids, hash_block_size, sha256)
while fill_req.num_computed_tokens < len(token_ids):
num_new_tokens = min(512, len(token_ids) - fill_req.num_computed_tokens)
blocks = manager.allocate_slots(fill_req, num_new_tokens)
assert blocks is not None
fill_req.num_computed_tokens += num_new_tokens
manager.free(fill_req)
return token_ids
token_ids = fill_request("fill_0", 0)
replay_req = make_request("replay", token_ids[:1800], hash_block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
assert num_computed_tokens == 1024
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
fill_request("fill_1", 100_000)
replay_req = make_request("replay_again", token_ids[:1800], hash_block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
assert num_computed_tokens == 1024
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
"""Verify latest-only retention reuses only the replayable prompt boundary."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
token_ids = [i for i in range(16) for _ in range(block_size)]
req0 = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req0)
blocks = manager.allocate_slots(
req0,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
expected_swa_cached = {11}
for i in range(16):
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
if i in expected_swa_cached:
assert cached is not None, f"SWA hash {i} should be cached"
else:
assert cached is None, f"SWA hash {i} should not be cached"
manager.free(req0)
retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1])
assert retained_swa_block is not None
assert retained_swa_block[0].ref_cnt == 0
req1 = make_request("1", token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Full prompt hits intentionally recompute the final block for logits, so
# the longest usable hit is the previous LCM boundary: 96 tokens.
assert num_computed_tokens == 12 * block_size
assert len(computed_blocks.blocks[1]) == 12
shorter_req = make_request("2", token_ids[: 12 * block_size], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req)
assert num_computed_tokens == 0
assert len(computed_blocks.blocks[1]) == 0
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
"""Verify MTP/EAGLE SWA retention keeps the extra proof block.
EAGLE/MTP lookup matches one additional local block after the returned
prefix and then drops it. Sparse retention must therefore cache the normal
local tail at the latest replay boundary plus one extra SWA block.
"""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["full"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["swa_mtp"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
is_eagle_group=True,
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
use_eagle=True,
)
# 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
# The EAGLE/MTP SWA lookup group must cache the local tail ending at
# 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12.
token_ids = [i for i in range(15) for _ in range(block_size)] + [15] * 7
req0 = make_request("0", token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0,
len(token_ids),
num_computed_tokens,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
expected_swa_cached = {11, 12}
for i in range(15):
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
if i in expected_swa_cached:
assert cached is not None, f"SWA hash {i} should be cached"
else:
assert cached is None, f"SWA hash {i} should not be cached"
manager.free(req0)
req1 = make_request("1", token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert num_computed_tokens == 12 * block_size
assert [len(blocks) for blocks in computed_blocks.blocks] == [3, 12]
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")
@@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
"""Default path (no retention): freeing an SWA request must place its
uncached scratch blocks at the front of the free queue (recycled first)
and keep its cached checkpoint blocks at the back (retained for prefix
hits). This split is always-on, independent of the retention interval."""
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
swa_manager = manager.coordinator.single_type_managers[1]
null_block = manager.block_pool.null_block
cached_ids: set[int] = set()
uncached_ids: set[int] = set()
cached_hash_indices: list[int] = []
for i, block in enumerate(swa_manager.req_to_blocks[req.request_id]):
if block is null_block:
continue
if block.block_hash is None:
uncached_ids.add(block.block_id)
else:
cached_ids.add(block.block_id)
cached_hash_indices.append(i)
# The dense default mask caches only the per-segment tails, so a 16-block
# SWA prompt must produce a mix of retained and scratch blocks.
assert cached_ids, "expected some retained (cached) SWA tail blocks"
assert uncached_ids, "expected some scratch (uncached) SWA blocks"
manager.free(req)
order = [
b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]
pos = {bid: i for i, bid in enumerate(order)}
# Every scratch block is recycled before every retained block.
assert max(pos[bid] for bid in uncached_ids) < min(pos[bid] for bid in cached_ids)
# The retained tails survive the free and still serve a prefix-cache hit.
for i in cached_hash_indices:
assert (
manager.block_pool.get_cached_block(
req.block_hashes[i], kv_cache_group_ids=[1]
)
is not None
)
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
"""Single sliding-window group (UnitaryKVCacheCoordinator)."""
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
)
return make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
**kwargs,
)
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
"""Sparse retention must work for a pure-SWA single-group model, not just
hybrid models: only the per-interval tails plus the latest replay tail are
cached, and a replay still hits the latest replayable boundary."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
block_size = 16
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
cached = {
i
for i in range(16)
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
is not None
}
# per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at
# i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail
# block 14. Crucially this is a strict subset of all 16 blocks: retention
# is actually sparse for pure SWA (not silently dense).
assert cached == {3, 7, 11, 14, 15}
# A replay of the same prompt hits the latest replayable boundary (240).
replay = make_request("1", token_ids, block_size, sha256)
_, num_computed = manager.get_computed_blocks(replay)
assert num_computed == 240
def test_pure_swa_retention_latest_only(monkeypatch):
"""`=0` on a pure-SWA model keeps only the latest replay tail."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
block_size = 16
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
cached = {
i
for i in range(16)
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
is not None
}
# No segment tails (interval 0); only the latest replay tail (block 14).
assert cached == {14}
replay = make_request("1", token_ids, block_size, sha256)
_, num_computed = manager.get_computed_blocks(replay)
assert num_computed == 240
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
"""With retention unset, a pure-SWA model must keep the dense behavior:
every block boundary is a potential hit, so all blocks are cached."""
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
block_size = 16
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
cached = {
i
for i in range(16)
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
is not None
}
assert cached == set(range(16))