mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Prefix Caching] DeepSeekv4 - Support selective prefix-cache retention for sliding-window KV cache (#43447)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai> Co-authored-by: Yifan Qiao <yifanqiao@inferact.ai>
This commit is contained in:
@@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n():
|
||||
)
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_prepend_n():
|
||||
# Seed the queue with one block so prepend has an existing head to splice
|
||||
# in front of (fake_head->b0->fake_tail).
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
|
||||
queue = FreeKVCacheBlockQueue(blocks[0:1])
|
||||
|
||||
# Prepend 0 blocks is a no-op.
|
||||
queue.prepend_n([])
|
||||
assert queue.num_free_blocks == 1
|
||||
assert queue.fake_free_list_head.next_free_block is blocks[0]
|
||||
|
||||
# Prepend 2 blocks; they land in front of the existing head, in order.
|
||||
# fake_head->b4->b5->b0->fake_tail
|
||||
queue.prepend_n(blocks[4:6])
|
||||
assert queue.num_free_blocks == 3
|
||||
assert queue.fake_free_list_head.next_free_block is blocks[4]
|
||||
assert blocks[4].prev_free_block is queue.fake_free_list_head
|
||||
assert blocks[4].next_free_block is blocks[5]
|
||||
assert blocks[5].prev_free_block is blocks[4]
|
||||
assert blocks[5].next_free_block is blocks[0]
|
||||
assert blocks[0].prev_free_block is blocks[5]
|
||||
assert blocks[0].next_free_block is queue.fake_free_list_tail
|
||||
assert queue.fake_free_list_tail.prev_free_block is blocks[0]
|
||||
|
||||
# A second prepend goes ahead of everything previously prepended.
|
||||
# fake_head->b1->b2->b4->b5->b0->fake_tail
|
||||
queue.prepend_n(blocks[1:3])
|
||||
assert queue.num_free_blocks == 5
|
||||
assert queue.fake_free_list_head.next_free_block is blocks[1]
|
||||
assert blocks[1].next_free_block is blocks[2]
|
||||
assert blocks[2].next_free_block is blocks[4]
|
||||
|
||||
# The popleft order reflects the front-to-back queue order.
|
||||
assert [queue.popleft().block_id for _ in range(5)] == [1, 2, 4, 5, 0]
|
||||
assert queue.num_free_blocks == 0
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_popleft_n():
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
|
||||
# Create an empty FreeKVCacheBlockQueue with these blocks
|
||||
|
||||
@@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpecKind,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
|
||||
@@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
|
||||
)
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
|
||||
"""Verify fixed intervals retain sparse tails plus the latest replay tail."""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
|
||||
block_size = 8
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=100,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(
|
||||
block_size=4 * block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
manager = make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# The SWA manager uses the configured 64-token interval (a multiple of the
|
||||
# 32-token lcm_block_size) as its retention segment. For this 128-token
|
||||
# prompt, the retained SWA tails are the 64-token interval boundary, the
|
||||
# 96-token replay boundary, and the 128-token interval boundary.
|
||||
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||
req = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
blocks = manager.allocate_slots(
|
||||
req,
|
||||
len(token_ids),
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
pool = manager.block_pool
|
||||
expected_swa_cached = {7, 11, 15}
|
||||
for i in range(16):
|
||||
cached = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1])
|
||||
if i in expected_swa_cached:
|
||||
assert cached is not None, f"SWA hash {i} should be cached"
|
||||
else:
|
||||
assert cached is None, f"SWA hash {i} should not be cached"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"interval, expected_match",
|
||||
[
|
||||
# scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
|
||||
("33", "multiple of scheduler_block_size"),
|
||||
# A negative multiple (-32 % 32 == 0) must still be rejected explicitly,
|
||||
# otherwise it would pass the modulo check and silently degrade to dense.
|
||||
("-32", "non-negative"),
|
||||
],
|
||||
)
|
||||
def test_hybrid_local_kv_retention_interval_rejects_invalid(
|
||||
monkeypatch, interval, expected_match
|
||||
):
|
||||
"""A retention interval that is negative or not a multiple of
|
||||
scheduler_block_size errors out at construction time."""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
|
||||
block_size = 8
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=100,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(
|
||||
block_size=4 * block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match=expected_match):
|
||||
make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
|
||||
"""Verify retained local checkpoints are reused after block recycling."""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
|
||||
hash_block_size = 4
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=800,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["full"],
|
||||
MLAAttentionSpec(
|
||||
block_size=64 * hash_block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.uint8,
|
||||
compress_ratio=4,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["swa"],
|
||||
SlidingWindowSpec(
|
||||
block_size=16 * hash_block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=512,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["c128"],
|
||||
SlidingWindowSpec(
|
||||
block_size=2 * hash_block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=128,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["c4"],
|
||||
SlidingWindowSpec(
|
||||
block_size=hash_block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=8,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
manager = make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=4096,
|
||||
enable_caching=True,
|
||||
hash_block_size=hash_block_size,
|
||||
)
|
||||
|
||||
def fill_request(request_id: str, token_offset: int) -> list[int]:
|
||||
token_ids = [
|
||||
token_offset + i for i in range(1024) for _ in range(hash_block_size)
|
||||
]
|
||||
fill_req = make_request(request_id, token_ids, hash_block_size, sha256)
|
||||
while fill_req.num_computed_tokens < len(token_ids):
|
||||
num_new_tokens = min(512, len(token_ids) - fill_req.num_computed_tokens)
|
||||
blocks = manager.allocate_slots(fill_req, num_new_tokens)
|
||||
assert blocks is not None
|
||||
fill_req.num_computed_tokens += num_new_tokens
|
||||
manager.free(fill_req)
|
||||
return token_ids
|
||||
|
||||
token_ids = fill_request("fill_0", 0)
|
||||
replay_req = make_request("replay", token_ids[:1800], hash_block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
|
||||
assert num_computed_tokens == 1024
|
||||
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
|
||||
|
||||
fill_request("fill_1", 100_000)
|
||||
replay_req = make_request("replay_again", token_ids[:1800], hash_block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
|
||||
assert num_computed_tokens == 1024
|
||||
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
|
||||
"""Verify latest-only retention reuses only the replayable prompt boundary."""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
|
||||
block_size = 8
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=100,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(
|
||||
block_size=4 * block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
manager = make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||
req0 = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req0)
|
||||
blocks = manager.allocate_slots(
|
||||
req0,
|
||||
len(token_ids),
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
pool = manager.block_pool
|
||||
expected_swa_cached = {11}
|
||||
for i in range(16):
|
||||
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
|
||||
if i in expected_swa_cached:
|
||||
assert cached is not None, f"SWA hash {i} should be cached"
|
||||
else:
|
||||
assert cached is None, f"SWA hash {i} should not be cached"
|
||||
|
||||
manager.free(req0)
|
||||
retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1])
|
||||
assert retained_swa_block is not None
|
||||
assert retained_swa_block[0].ref_cnt == 0
|
||||
|
||||
req1 = make_request("1", token_ids, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
# Full prompt hits intentionally recompute the final block for logits, so
|
||||
# the longest usable hit is the previous LCM boundary: 96 tokens.
|
||||
assert num_computed_tokens == 12 * block_size
|
||||
assert len(computed_blocks.blocks[1]) == 12
|
||||
|
||||
shorter_req = make_request("2", token_ids[: 12 * block_size], block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req)
|
||||
assert num_computed_tokens == 0
|
||||
assert len(computed_blocks.blocks[1]) == 0
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
|
||||
"""Verify MTP/EAGLE SWA retention keeps the extra proof block.
|
||||
|
||||
EAGLE/MTP lookup matches one additional local block after the returned
|
||||
prefix and then drops it. Sparse retention must therefore cache the normal
|
||||
local tail at the latest replay boundary plus one extra SWA block.
|
||||
"""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
|
||||
block_size = 8
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=100,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["full"],
|
||||
FullAttentionSpec(
|
||||
block_size=4 * block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["swa_mtp"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
),
|
||||
is_eagle_group=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
manager = make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
use_eagle=True,
|
||||
)
|
||||
|
||||
# 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
|
||||
# The EAGLE/MTP SWA lookup group must cache the local tail ending at
|
||||
# 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12.
|
||||
token_ids = [i for i in range(15) for _ in range(block_size)] + [15] * 7
|
||||
req0 = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(
|
||||
req0,
|
||||
len(token_ids),
|
||||
num_computed_tokens,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
pool = manager.block_pool
|
||||
expected_swa_cached = {11, 12}
|
||||
for i in range(15):
|
||||
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
|
||||
if i in expected_swa_cached:
|
||||
assert cached is not None, f"SWA hash {i} should be cached"
|
||||
else:
|
||||
assert cached is None, f"SWA hash {i} should not be cached"
|
||||
|
||||
manager.free(req0)
|
||||
|
||||
req1 = make_request("1", token_ids, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert num_computed_tokens == 12 * block_size
|
||||
assert [len(blocks) for blocks in computed_blocks.blocks] == [3, 12]
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
@@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
|
||||
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
|
||||
|
||||
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
|
||||
|
||||
|
||||
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
|
||||
"""Default path (no retention): freeing an SWA request must place its
|
||||
uncached scratch blocks at the front of the free queue (recycled first)
|
||||
and keep its cached checkpoint blocks at the back (retained for prefix
|
||||
hits). This split is always-on, independent of the retention interval."""
|
||||
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
|
||||
block_size = 8
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=100,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(
|
||||
block_size=4 * block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
manager = make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||
req = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
blocks = manager.allocate_slots(
|
||||
req,
|
||||
len(token_ids),
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
swa_manager = manager.coordinator.single_type_managers[1]
|
||||
null_block = manager.block_pool.null_block
|
||||
cached_ids: set[int] = set()
|
||||
uncached_ids: set[int] = set()
|
||||
cached_hash_indices: list[int] = []
|
||||
for i, block in enumerate(swa_manager.req_to_blocks[req.request_id]):
|
||||
if block is null_block:
|
||||
continue
|
||||
if block.block_hash is None:
|
||||
uncached_ids.add(block.block_id)
|
||||
else:
|
||||
cached_ids.add(block.block_id)
|
||||
cached_hash_indices.append(i)
|
||||
# The dense default mask caches only the per-segment tails, so a 16-block
|
||||
# SWA prompt must produce a mix of retained and scratch blocks.
|
||||
assert cached_ids, "expected some retained (cached) SWA tail blocks"
|
||||
assert uncached_ids, "expected some scratch (uncached) SWA blocks"
|
||||
|
||||
manager.free(req)
|
||||
|
||||
order = [
|
||||
b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
]
|
||||
pos = {bid: i for i, bid in enumerate(order)}
|
||||
# Every scratch block is recycled before every retained block.
|
||||
assert max(pos[bid] for bid in uncached_ids) < min(pos[bid] for bid in cached_ids)
|
||||
# The retained tails survive the free and still serve a prefix-cache hit.
|
||||
for i in cached_hash_indices:
|
||||
assert (
|
||||
manager.block_pool.get_cached_block(
|
||||
req.block_hashes[i], kv_cache_group_ids=[1]
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
|
||||
"""Single sliding-window group (UnitaryKVCacheCoordinator)."""
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=sliding_window,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return make_kv_cache_manager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
|
||||
"""Sparse retention must work for a pure-SWA single-group model, not just
|
||||
hybrid models: only the per-interval tails plus the latest replay tail are
|
||||
cached, and a replay still hits the latest replayable boundary."""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
|
||||
block_size = 16
|
||||
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
|
||||
assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"
|
||||
|
||||
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||
req = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
blocks = manager.allocate_slots(
|
||||
req,
|
||||
len(token_ids),
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
pool = manager.block_pool
|
||||
cached = {
|
||||
i
|
||||
for i in range(16)
|
||||
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
|
||||
is not None
|
||||
}
|
||||
# per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at
|
||||
# i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail
|
||||
# block 14. Crucially this is a strict subset of all 16 blocks: retention
|
||||
# is actually sparse for pure SWA (not silently dense).
|
||||
assert cached == {3, 7, 11, 14, 15}
|
||||
|
||||
# A replay of the same prompt hits the latest replayable boundary (240).
|
||||
replay = make_request("1", token_ids, block_size, sha256)
|
||||
_, num_computed = manager.get_computed_blocks(replay)
|
||||
assert num_computed == 240
|
||||
|
||||
|
||||
def test_pure_swa_retention_latest_only(monkeypatch):
|
||||
"""`=0` on a pure-SWA model keeps only the latest replay tail."""
|
||||
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
|
||||
block_size = 16
|
||||
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
|
||||
|
||||
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||
req = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
blocks = manager.allocate_slots(
|
||||
req,
|
||||
len(token_ids),
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
pool = manager.block_pool
|
||||
cached = {
|
||||
i
|
||||
for i in range(16)
|
||||
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
|
||||
is not None
|
||||
}
|
||||
# No segment tails (interval 0); only the latest replay tail (block 14).
|
||||
assert cached == {14}
|
||||
|
||||
replay = make_request("1", token_ids, block_size, sha256)
|
||||
_, num_computed = manager.get_computed_blocks(replay)
|
||||
assert num_computed == 240
|
||||
|
||||
|
||||
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
|
||||
"""With retention unset, a pure-SWA model must keep the dense behavior:
|
||||
every block boundary is a potential hit, so all blocks are cached."""
|
||||
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
|
||||
block_size = 16
|
||||
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
|
||||
|
||||
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||
req = make_request("0", token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
blocks = manager.allocate_slots(
|
||||
req,
|
||||
len(token_ids),
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
|
||||
pool = manager.block_pool
|
||||
cached = {
|
||||
i
|
||||
for i in range(16)
|
||||
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
|
||||
is not None
|
||||
}
|
||||
assert cached == set(range(16))
|
||||
|
||||
Reference in New Issue
Block a user