mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
a6183563b6
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai> Co-authored-by: Yifan Qiao <yifanqiao@inferact.ai>
3618 lines
124 KiB
Python
3618 lines
124 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Compare the with and without prefix caching."""
|
|
|
|
import copy
|
|
from collections.abc import Callable
|
|
from math import lcm
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import vllm.v1.core.kv_cache_manager as kv_cache_manager
|
|
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
|
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved, BlockStored
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalFeatureSpec,
|
|
MultiModalKwargsItem,
|
|
PlaceholderRange,
|
|
)
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.utils.hashing import sha256, sha256_cbor
|
|
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
|
|
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
|
from vllm.v1.core.kv_cache_utils import (
|
|
BlockHash,
|
|
BlockHashWithGroupId,
|
|
KVCacheBlock,
|
|
get_block_hash,
|
|
get_group_id,
|
|
get_request_block_hasher,
|
|
hash_block_tokens,
|
|
init_none_hash,
|
|
make_block_hash_with_group_id,
|
|
)
|
|
from vllm.v1.kv_cache_interface import (
|
|
FullAttentionSpec,
|
|
KVCacheConfig,
|
|
KVCacheGroupSpec,
|
|
KVCacheSpecKind,
|
|
MambaSpec,
|
|
MLAAttentionSpec,
|
|
SlidingWindowSpec,
|
|
)
|
|
|
|
# Apply the ``cpu_test`` marker to every test in this module (pytest picks up
# the module-level ``pytestmark`` variable automatically).
pytestmark = pytest.mark.cpu_test
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _auto_init_hash_fn(request):
    """Autouse fixture: call ``init_none_hash`` before every test.

    Tests that take a ``hash_fn`` fixture/parameter pass that function through;
    every other test falls back to ``sha256``.
    """
    wants_param = "hash_fn" in request.fixturenames
    hash_fn: Callable = request.getfixturevalue("hash_fn") if wants_param else sha256
    init_none_hash(hash_fn)
|
|
|
|
|
|
def make_request(
    request_id: str,
    prompt_token_ids: list[int],
    block_size: int,
    hash_fn: Callable,
    mm_positions: list[PlaceholderRange] | None = None,
    mm_hashes: list[str] | None = None,
    prompt_logprobs: int | None = None,
    cache_salt: str | None = None,
    lora_request: LoRARequest | None = None,
):
    """Build a ``Request`` suitable for the KV cache manager tests.

    Args:
        request_id: Identifier of the request.
        prompt_token_ids: Prompt tokens.
        block_size: Block size handed to the request's block hasher.
        hash_fn: Hash function handed to the request's block hasher.
        mm_positions: Optional placeholder ranges; one dummy image feature is
            created per range.
        mm_hashes: Optional identifiers for those features. When falsy, the
            j-th feature is identified as ``f"hash_{j}"``.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cache_salt: Forwarded to ``Request``.
        lora_request: Forwarded to ``Request``.

    Returns:
        A ``Request`` with ``max_tokens=17`` and ``eos_token_id=100``; its
        ``mm_features`` is ``None`` when no multimodal positions were given.
    """

    def _image_feature(idx: int, placeholder: PlaceholderRange):
        feature_id = mm_hashes[idx] if mm_hashes else f"hash_{idx}"
        return MultiModalFeatureSpec(
            data=MultiModalKwargsItem.dummy(),
            mm_position=placeholder,
            identifier=feature_id,
            modality="image",
        )

    placeholders = [] if mm_positions is None else mm_positions
    features = [_image_feature(idx, ph) for idx, ph in enumerate(placeholders)]

    params = SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs)
    params.update_from_generation_config({}, eos_token_id=100)

    return Request(
        request_id=request_id,
        prompt_token_ids=prompt_token_ids,
        mm_features=features or None,
        sampling_params=params,
        pooling_params=None,
        lora_request=lora_request,
        cache_salt=cache_salt,
        block_hasher=get_request_block_hasher(block_size, hash_fn),
    )
|
|
|
|
|
|
def make_kv_cache_manager(kv_cache_config: KVCacheConfig, **kwargs) -> KVCacheManager:
    """Construct a ``KVCacheManager`` for these tests.

    ``scheduler_block_size`` defaults to the least common multiple of all KV
    cache group block sizes, while an explicit value in ``kwargs`` wins. This
    mirrors ``resolve_kv_cache_block_sizes`` for the non-context-parallel case
    used here, so individual tests need not pass it.
    """
    group_block_sizes = (
        group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups
    )
    # Caller-supplied kwargs come last so they override the derived default.
    options = {"scheduler_block_size": lcm(*group_block_sizes), **kwargs}
    return KVCacheManager(kv_cache_config, **options)
|
|
|
|
|
|
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
    """Return a single-group config: one full-attention layer named ``"layer"``."""
    full_attention = FullAttentionSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
    )
    only_group = KVCacheGroupSpec(["layer"], full_attention)
    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[only_group],
    )
|
|
|
|
|
|
def make_kv_cache_config_hybrid_model(
    block_size: int,
    num_blocks: int,
    sliding_window_blocks: int,
    second_spec_type: str = "sliding_window",
) -> KVCacheConfig:
    """Build a three-group hybrid config.

    Group 0 (``layer1``) is full attention. Groups 1 and 2 (``layer2`` and
    ``layer3``) share a single second spec selected by ``second_spec_type``.

    Args:
        block_size: Block size used by every group.
        num_blocks: Total number of KV cache blocks.
        sliding_window_blocks: Sliding window size in blocks. Only used when
            ``second_spec_type == "sliding_window"``.
        second_spec_type: ``"sliding_window"`` or ``"mamba"``.

    Raises:
        ValueError: If ``second_spec_type`` is not one of the supported values.
    """
    if second_spec_type == "sliding_window":
        second_spec = SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=sliding_window_blocks * block_size,
        )
    elif second_spec_type == "mamba":
        second_spec = MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        )
    else:
        # Without this branch an unknown type would surface later as an
        # UnboundLocalError on ``second_spec``.
        raise ValueError(
            f"Unsupported second_spec_type {second_spec_type!r}; "
            "expected 'sliding_window' or 'mamba'."
        )

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
            KVCacheGroupSpec(["layer2"], second_spec),
            KVCacheGroupSpec(["layer3"], second_spec),
        ],
    )
|
|
|
|
|
|
def make_kv_cache_config_three_types(
    block_size: int, num_blocks: int, third_spec_type: str = "mamba"
) -> KVCacheConfig:
    """Build a three-group config mixing three attention types.

    Groups, in order:
      * ``layer1``: full attention.
      * ``layer2``: sliding window of ``2 * block_size`` tokens.
      * ``layer3``: chosen by ``third_spec_type``. ``"mamba"`` gives a Mamba
        spec; ``"sliding_window"`` gives a wider window of ``4 * block_size``
        tokens.

    Raises:
        ValueError: If ``third_spec_type`` is not one of the supported values.
    """
    if third_spec_type == "mamba":
        third_spec = MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        )
    elif third_spec_type == "sliding_window":
        third_spec = SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=4 * block_size,
        )
    else:
        # Without this branch an unknown type would surface later as an
        # UnboundLocalError on ``third_spec``.
        raise ValueError(
            f"Unsupported third_spec_type {third_spec_type!r}; "
            "expected 'mamba' or 'sliding_window'."
        )

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
            KVCacheGroupSpec(
                ["layer2"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=2 * block_size,
                ),
            ),
            KVCacheGroupSpec(["layer3"], third_spec),
        ],
    )
|
|
|
|
|
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_prefill(hash_fn):
    """Prefix caching with a single full-attention group.

    Walks through four phases: a cold cache miss, a hit while the source blocks
    are still in use, a hit after they were freed, and a miss large enough to
    evict every cached block. Along the way it checks block ids, hashes, ref
    counts and free-queue order.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(block_size)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == 3
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0,
        len(all_token_ids),
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],)

    # Check full block metadata. Each full block stores the hash chained from
    # its parent, tagged with group 0, and is referenced only by req0.
    parent_block_hash = None
    for block_id in (1, 2, 3):
        start = (block_id - 1) * block_size
        block_tokens = tuple(all_token_ids[start : start + block_size])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
        blk_hash = manager.block_pool.blocks[block_id].block_hash
        assert blk_hash is not None
        assert get_block_hash(blk_hash) == block_hash
        assert get_group_id(blk_hash) == 0
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash

    # Check partial block metadata: the trailing partial block is not hashed.
    partial_block = manager.block_pool.blocks[4]
    assert partial_block.block_hash is None
    assert partial_block.ref_cnt == 1

    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", req1_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == 3
    assert computed_blocks.get_block_ids() == ([1, 2, 3],)
    assert num_computed_tokens == 3 * block_size
    num_new_tokens = len(req1_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1,
        num_new_tokens,
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([5],)
    for block in computed_blocks.blocks[0]:
        assert block.ref_cnt == 2

    # At this point, we should have 5 free blocks left
    # (10 usable blocks - 4 held by req0 - 1 new block for req1).
    free_block_queue = manager.block_pool.free_block_queue
    assert free_block_queue.num_free_blocks == 5

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (6, 7, 8, 9, 10)]
    # [unique_req0 (4)]
    # [unique_req1 (5)]
    # [common (3, 2, 1)]
    assert [
        b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]

    # Cache hit in the common prefix when the original block is already free.
    # Incomplete 1 block (6 tokens)
    unique_token_ids = [3] * 6
    req2_token_ids = common_token_ids + unique_token_ids
    req2 = make_request("2", req2_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(req2.block_hashes) == 3
    assert computed_blocks.get_block_ids() == ([1, 2, 3],)
    assert num_computed_tokens == 3 * block_size
    # Schedule exactly the uncached tail of req2 (6 tokens). The old code
    # reused req1's 53-token length here.
    num_new_tokens = len(req2_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req2,
        num_new_tokens,
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([6],)

    # The hit blocks 1-3 and the newly allocated block 6 are no longer in the
    # free queue, so it holds exactly 6 blocks, all unreferenced.
    assert free_block_queue.num_free_blocks == 6
    assert all(b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks())
    assert len(list(free_block_queue.get_all_free_blocks())) == 6

    manager.free(req2)

    # Cache miss and eviction.
    req3 = make_request("3", [99] * (block_size * 10), block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req3,
        block_size * 10,
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    # This block ID order also checks the eviction order.
    assert blocks is not None and blocks.get_block_ids() == (
        [7, 8, 9, 10, 4, 5, 6, 3, 2, 1],
    )

    # The free queue is now empty: its sentinels point at each other.
    assert free_block_queue.num_free_blocks == 0
    assert (
        free_block_queue.fake_free_list_head.next_free_block
        is free_block_queue.fake_free_list_tail
    )
    assert (
        free_block_queue.fake_free_list_tail.prev_free_block
        is free_block_queue.fake_free_list_head
    )
|
|
|
|
|
|
def test_prefill_hybrid_model():
    """Prefix caching with one full-attention group and two sliding-window
    groups (window of 2 blocks), including partial-eviction scenarios."""
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config_hybrid_model(block_size, 21, 2),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    hash_fn = sha256

    # Complete 3 blocks (48 tokens)
    num_full_blocks = 3
    common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == 3
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0,
        len(all_token_ids),
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == (
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
    )

    # Check full block metadata. The i-th full block of every group carries
    # the same chained hash, tagged with that group's id.
    parent_block_hash = None
    for length, block_ids in zip((1, 2, 3), ((1, 5, 9), (2, 6, 10), (3, 7, 11))):
        block_tokens = tuple(
            all_token_ids[(length - 1) * block_size : length * block_size]
        )
        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
        for group_id, block_id in enumerate(block_ids):
            blk_hash = manager.block_pool.blocks[block_id].block_hash
            assert blk_hash is not None
            assert get_block_hash(blk_hash) == block_hash
            assert get_group_id(blk_hash) == group_id
            assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash

    # Check partial block metadata
    for block_id in (4, 8, 12):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # Cache hit in the common prefix
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    all_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == 3
    # Sliding-window groups report block id 0 (skipped as the null block
    # below) for the position they do not need.
    assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11])
    assert num_computed_tokens == 3 * block_size
    num_new_tokens = len(all_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1,
        num_new_tokens,
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
    for block_per_group in computed_blocks.blocks:
        for block in block_per_group:
            if block != manager.block_pool.null_block:
                assert block.ref_cnt == 2

    block_hashes = req1.block_hashes
    manager.free(req0)
    manager.free(req1)

    # Eviction scenarios: (request id, [(block index, group id), ...] to evict,
    # expected hit length in blocks). Group 0 is full attention; groups 1 and 2
    # are sliding window.
    eviction_cases = [
        # Evict the blocks outside sliding window, does not affect the hit length.
        ("2", [(0, 1), (0, 2)], 3),
        # Evict the first block of full attention, makes total cache miss.
        ("3", [(0, 0)], 0),
        # Evict the last block of all layers, reduces the hit length to 2.
        ("4", [(2, 0), (2, 1), (2, 2)], 2),
        # Evict the last block of full attention, reduces the hit length to 2.
        ("5", [(2, 0)], 2),
        # Evict the last block of the first sliding-window group (group 1),
        # reduces the hit length to 2.
        ("6", [(2, 1)], 2),
        # Evict the last block of the second sliding-window group (group 2),
        # reduces the hit length to 2.
        ("7", [(2, 2)], 2),
        # Evict different set of blocks for full attention and sliding window
        # makes total cache miss.
        # The cache hit length of full attention is 1 * block_size.
        # The cache hit length of sliding window is 2 * block_size.
        # Then it is cache miss as the two type of layers
        # have different hit length.
        ("8", [(2, 0), (0, 1), (0, 2)], 0),
    ]
    for request_id, evictions, expect_hit_length in eviction_cases:
        _test_partial_request_hit(
            manager,
            block_size,
            num_full_blocks,
            request_id,
            all_token_ids,
            [
                make_block_hash_with_group_id(block_hashes[idx], group_id)
                for idx, group_id in evictions
            ],
            expect_hit_length,
        )
|
|
|
|
|
|
def test_prefill_hybrid_model_eagle():
    """Hybrid prefix caching with EAGLE enabled.

    The layout is one full-attention group plus two sliding-window groups with
    a 3-block window. With EAGLE the last matching full block is dropped, so a
    prompt sharing 6 full blocks with a cached request reuses only 5.
    """
    block_size = 16
    kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3)
    manager = make_kv_cache_manager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    hash_fn = sha256

    # Six complete blocks (96 tokens) form the shared prefix.
    num_full_blocks = 6
    common_token_ids = [
        tok for tok in range(num_full_blocks) for _ in range(block_size)
    ]

    # Request 0: cold cache; the prefix is followed by a 7-token partial block.
    prompt0 = common_token_ids + [6] * 7
    req0 = make_request("0", prompt0, block_size, hash_fn)
    hit_blocks, hit_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == len(prompt0) // block_size
    assert not hit_blocks.blocks[0]
    assert hit_tokens == 0
    allocated = manager.allocate_slots(req0, len(prompt0), hit_tokens, hit_blocks)
    expected_ids = (
        [1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 20, 21],
    )
    assert allocated is not None and allocated.get_block_ids() == expected_ids

    # Full blocks: every group stores the same chained hash at each position,
    # tagged with its own group id, and each block is referenced once.
    parent_hash = None
    num_hashed_positions = len(expected_ids[0]) - 1  # last column is partial
    for pos in range(num_hashed_positions):
        tokens = tuple(prompt0[pos * block_size : (pos + 1) * block_size])
        expected_hash = hash_block_tokens(hash_fn, parent_hash, tokens)
        for group_id, row in enumerate(expected_ids):
            stored = manager.block_pool.blocks[row[pos]].block_hash
            assert stored is not None
            assert get_block_hash(stored) == expected_hash
            assert get_group_id(stored) == group_id
            assert manager.block_pool.blocks[row[pos]].ref_cnt == 1
        parent_hash = expected_hash

    # Trailing partial block of each group: unhashed, referenced once.
    for row in expected_ids:
        tail_block = manager.block_pool.blocks[row[-1]]
        assert tail_block.block_hash is None
        assert tail_block.ref_cnt == 1

    # Request 1: same prefix with a 5-token tail; it should reuse the prefix.
    prompt1 = common_token_ids + [6] * 5
    req1 = make_request("1", prompt1, block_size, hash_fn)
    hit_blocks, hit_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == num_full_blocks
    assert hit_blocks.get_block_ids() == (
        [1, 2, 3, 4, 5],
        [0, 0, 10, 11, 12],
        [0, 0, 17, 18, 19],
    )
    assert hit_tokens == 5 * block_size
    allocated = manager.allocate_slots(
        req1, len(prompt1) - hit_tokens, hit_tokens, hit_blocks
    )
    assert allocated is not None and allocated.get_block_ids() == (
        [22, 23],
        [24, 25],
        [26, 27],
    )
    for group_blocks in hit_blocks.blocks:
        for block in group_blocks:
            if block != manager.block_pool.null_block:
                assert block.ref_cnt == 2

    block_hashes = req1.block_hashes
    manager.free(req0)
    manager.free(req1)

    # Eviction scenarios: (request id, [(block index, group id), ...] to evict,
    # expected hit length in blocks). Group 0 is full attention; groups 1 and 2
    # are sliding window.
    scenarios = [
        # Blocks outside the sliding window do not affect the hit length.
        ("2", [(0, 1), (0, 2)], 5),
        # Losing the first full-attention block is a total miss.
        ("3", [(0, 0)], 0),
        # Losing the last block in every group shortens the hit to 4.
        ("4", [(-1, 0), (-1, 1), (-1, 2)], 4),
        # Losing the last full-attention block shortens the hit to 4.
        ("5", [(-1, 0)], 4),
        # EAGLE already drops the last full block, so losing the second-last
        # block of either sliding-window group shortens the hit to 3.
        ("6", [(-2, 1)], 3),
        ("7", [(-2, 2)], 3),
        # Full attention loses its last block, leaving 4 after the EAGLE drop.
        # Sliding window loses block 0, which lies outside the window of the
        # final hit, so both groups agree on 4.
        ("8", [(-1, 0), (0, 1), (0, 2)], 4),
    ]
    for request_id, evictions, expect_hit_length in scenarios:
        _test_partial_request_hit(
            manager,
            block_size,
            num_full_blocks,
            request_id,
            prompt1,
            [
                make_block_hash_with_group_id(block_hashes[idx], group_id)
                for idx, group_id in evictions
            ],
            expect_hit_length,
        )
|
|
|
|
|
|
def _test_partial_request_hit(
    manager: KVCacheManager,
    block_size: int,
    num_full_blocks: int,
    request_id: str,
    prompt_token_ids: list[int],
    hash_to_evict: list[BlockHashWithGroupId],
    expect_hit_length: int,
):
    """Temporarily evict cache entries and check the resulting prefix hit.

    Removes each key in ``hash_to_evict`` from the block pool's hash-to-block
    map and looks up computed blocks for a fresh request. It then asserts that
    every KV cache group hits exactly ``expect_hit_length`` blocks. The evicted
    entries are put back afterwards, even when an assertion fails, and the
    request is freed on success.

    Args:
        manager: Manager whose block pool is inspected and temporarily mutated.
        block_size: Tokens per block.
        num_full_blocks: Expected number of block hashes for the request.
        request_id: Id of the probe request.
        prompt_token_ids: Prompt of the probe request (hashed with ``sha256``).
        hash_to_evict: Cache keys to remove for the duration of the check.
        expect_hit_length: Expected hit length, in blocks.
    """
    cache = manager.block_pool.cached_block_hash_to_block._cache
    req = make_request(request_id, prompt_token_ids, block_size, sha256)
    # Keep only the removed entries instead of copying the whole map.
    evicted = {key: cache.pop(key) for key in hash_to_evict}
    try:
        computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
        assert len(req.block_hashes) == num_full_blocks
        assert num_computed_tokens == expect_hit_length * block_size
        # Every group must report the same number of hit blocks.
        for block_per_group in computed_blocks.blocks:
            assert len(block_per_group) == num_computed_tokens // block_size
    finally:
        # Put the evicted entries back so the next scenario starts from the
        # same cached contents.
        cache.update(evicted)
    manager.free(req)
|
|
|
|
|
|
def _make_hybrid_kv_cache_config(
    block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
    """
    Create a KVCacheConfig with one single-layer group per entry of spec_types.

    Group ``i`` contains the layer ``f"layer{i}"``.

    Args:
        block_size: The block size for KV cache.
        num_blocks: The number of blocks in the KV cache.
        spec_types: List of spec type strings. Supported types:
            - "full": FullAttentionSpec
            - "sliding_window": SlidingWindowSpec with window=2*block_size
            - "sliding_window_large": SlidingWindowSpec with window=4*block_size
            - "mamba": MambaSpec
            - "mamba_align": MambaSpec with mamba_cache_mode="align"

    Raises:
        ValueError: If any entry of spec_types is not a supported type.
    """
    # Factories (not instances) so every group gets its own spec object.
    spec_map = {
        "full": lambda: FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        ),
        "sliding_window": lambda: SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=2 * block_size,
        ),
        "sliding_window_large": lambda: SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=4 * block_size,
        ),
        "mamba": lambda: MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        ),
        "mamba_align": lambda: MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
            mamba_cache_mode="align",
        ),
    }

    unknown = [spec_type for spec_type in spec_types if spec_type not in spec_map]
    if unknown:
        raise ValueError(
            f"Unsupported spec type(s) {unknown}; expected one of {sorted(spec_map)}"
        )

    kv_cache_groups = [
        KVCacheGroupSpec([f"layer{i}"], spec_map[spec_type]())
        for i, spec_type in enumerate(spec_types)
    ]

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=kv_cache_groups,
    )
|
|
|
|
|
|
# Test cases covering various combinations of KV cache spec types:
# - Varying number of groups (2, 3, or 4)
# - 0, 1, or 2 full attention groups
# - Sliding window with different window sizes
# - Interleaved group IDs (full attn and other types mixed)
# - Mamba spec combinations
_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], id="2g-full+sw"),
    pytest.param(["full", "mamba"], id="2g-full+mamba"),
    # 2 groups: 0 full (all other types)
    pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"),
    pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"),
    # 3 groups: 1 full + 2 others (same type)
    pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"),
    pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"),
    # 3 groups: 1 full + 2 others (different types)
    pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"),
    pytest.param(
        ["full", "sliding_window", "sliding_window_large"],
        id="3g-full+sw+sw_large",
    ),
    # 3 groups: 2 full + 1 other
    pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"),
    pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"),
    # 4 groups: interleaved (full, other, full, other)
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window_large"],
        id="4g-interleaved-full+sw+sw_large",
    ),
    pytest.param(
        ["full", "mamba", "full", "mamba"],
        id="4g-interleaved-full+mamba",
    ),
    # 4 groups: interleaved with different sliding windows, larger window first.
    # This used to duplicate the "4g-interleaved-full+sw+sw_large" case exactly.
    pytest.param(
        ["full", "sliding_window_large", "full", "sliding_window"],
        id="4g-interleaved-full+sw_mixed",
    ),
    # 4 groups: 0 full (all other types)
    pytest.param(
        ["sliding_window", "mamba", "sliding_window_large", "mamba"],
        id="4g-sw+mamba+sw_large+mamba",
    ),
    # 4 groups: 2 full + 2 others (grouped)
    pytest.param(
        ["full", "full", "sliding_window", "mamba"],
        id="4g-2full+sw+mamba",
    ),
]
|
|
|
|
|
|
@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations(spec_types: list[str]):
    """Prefix caching across many hybrid KV cache layouts.

    Each parametrization builds one group per entry of ``spec_types`` and
    covers:
    - full attention mixed with other attention types
    - 2, 3 or 4 groups
    - 0, 1 or 2 full-attention groups
    - two sliding-window groups with different window sizes
    - interleaved group ids (full attention alternating with other types)
    - Mamba alongside other attention types

    A first request warms the cache. A second request sharing its 3-block
    prefix must then hit all 3 blocks in every group.
    """
    block_size = 16
    num_groups = len(spec_types)
    manager = make_kv_cache_manager(
        # Ten blocks per group is plenty for the two short requests below.
        _make_hybrid_kv_cache_config(block_size, 10 * num_groups, spec_types),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    hash_fn = sha256
    shared_prefix = [tok for tok in range(3) for _ in range(block_size)]  # 3 full blocks

    # Request 0: the cache is cold, so nothing is reused.
    first_prompt = shared_prefix + [3] * 7
    req0 = make_request("0", first_prompt, block_size, hash_fn)
    hit, hit_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == 3
    assert not hit.blocks[0]
    assert hit_tokens == 0
    allocated = manager.allocate_slots(
        req0, len(first_prompt), len(hit.blocks[0]) * block_size, hit
    )
    assert allocated is not None
    assert len(allocated.get_block_ids()) == num_groups  # one block list per group

    manager.new_step_starts()

    # Request 1: same prefix, different tail; all groups should hit the prefix.
    second_tail = [4] * 5
    req1 = make_request("1", shared_prefix + second_tail, block_size, hash_fn)
    hit, hit_tokens = manager.get_computed_blocks(req1)
    assert hit_tokens == 3 * block_size
    assert len(hit.blocks) == num_groups

    allocated = manager.allocate_slots(
        req1,
        len(shared_prefix) + len(second_tail) - hit_tokens,
        hit_tokens,
        hit,
    )
    assert allocated is not None
    assert len(allocated.get_block_ids()) == num_groups

    manager.free(req0)
    manager.free(req1)
|
|
|
|
|
|
# EAGLE variants of the hybrid-model prefix-caching test. Each entry is
# (spec_types, expected hit length in blocks). For now only the simplest
# layout is covered: one full-attention group plus one sliding-window group.
_EAGLE_HYBRID_MODEL_TEST_CASES = [
    pytest.param(
        ["full", "sliding_window"],  # spec_types
        3,  # expect_hit_length
        id="2g-full+sw",
    ),
]
|
|
|
|
|
|
@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations_eagle(
    spec_types: list[str], expect_hit_length: int
):
    """
    Test prefix caching with hybrid models (1 full attn + 1 other) with EAGLE.

    The second request shares 4 full blocks with the first one. With EAGLE the
    last matching full block is dropped, so the expected hit is shorter than
    the shared prefix (3 blocks for the current cases).

    More complex hybrid models with EAGLE are not yet supported (see issue #32802).
    """
    block_size = 16
    num_groups = len(spec_types)
    # Allocate enough blocks for all groups
    num_blocks = 10 * num_groups

    kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
    manager = make_kv_cache_manager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    hash_fn = sha256

    # Complete 4 blocks (64 tokens)
    num_full_blocks = 4
    common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
    unique_token_ids = [4] * 7
    all_token_ids = common_token_ids + unique_token_ids

    # First request: no cache hit initially
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

    assert len(req0.block_hashes) == num_full_blocks
    assert not computed_blocks.blocks[0]  # No cache hit initially
    assert num_computed_tokens == 0

    blocks = manager.allocate_slots(
        req0, len(all_token_ids), num_computed_tokens, computed_blocks
    )
    assert blocks is not None
    # Should have blocks for all groups
    assert len(blocks.get_block_ids()) == num_groups

    # Second request: should hit cached blocks for common prefix
    all_token_ids = common_token_ids + [6] * 5
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)

    # Should hit cached blocks for all groups
    assert num_computed_tokens == expect_hit_length * block_size
    assert len(computed_blocks.blocks) == num_groups
    # Verify each group has the correct number of computed blocks
    for block_per_group in computed_blocks.blocks:
        assert len(block_per_group) == expect_hit_length

    # Allocate and verify blocks for second request
    blocks = manager.allocate_slots(
        req1,
        len(all_token_ids) - num_computed_tokens,
        num_computed_tokens,
        computed_blocks,
    )
    assert blocks is not None
    assert len(blocks.get_block_ids()) == num_groups

    manager.free(req0)
    manager.free(req1)
|
|
|
|
|
|
def test_prefill_hybrid_model_mamba_align():
    """Regression test for MambaManager.cache_blocks() with null-block padding.

    See https://github.com/vllm-project/vllm/issues/34361.

    With mamba_cache_mode="align", allocate_new_blocks() pads req_to_blocks
    with null blocks. cache_full_blocks() already skips those, and
    MambaManager.cache_blocks() has to skip them as well when it records
    cached_blocks_this_step; otherwise allocate_slots() trips an assertion.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        _make_hybrid_kv_cache_config(block_size, 30, ["full", "mamba_align"]),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Three complete blocks (48 tokens) followed by 7 tokens of a partial block.
    prompt = [blk for blk in range(3) for _ in range(block_size)]
    prompt.extend([3] * 7)
    num_prompt_tokens = len(prompt)  # 55

    req = make_request("0", prompt, block_size, sha256)
    computed, num_computed = manager.get_computed_blocks(req)
    assert num_computed == 0

    # Must not raise inside MambaManager.cache_blocks() despite the null blocks.
    allocated = manager.allocate_slots(req, num_prompt_tokens, num_computed, computed)
    assert allocated is not None
    # One block-id list per KV cache group: full attention + mamba.
    assert len(allocated.get_block_ids()) == 2

    manager.free(req)
|
|
|
|
|
|
def test_prefill_plp():
    """Test prefill with APC and some prompt logprobs (plp) requests.

    1. Schedule a plp request and validate APC block allocation.
    2. Schedule a non-plp request and validate that it hits the blocks
       cached in step 1.
    3. Schedule another plp request; no hit should occur, and its freshly
       allocated blocks must carry the same hashes as the step-1 blocks.

    Every ``allocate_slots`` call is given exactly the number of prompt
    tokens that still need computing, derived from the prompt itself.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )
    # the default hash function is sha256
    hash_fn = sha256

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(block_size)]

    # Request #0 is a prompt logprobs request
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, block_size, hash_fn, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == 3
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0,
        len(all_token_ids),
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],)
    req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]

    # Check full block metadata: each full block hash chains onto its parent.
    parent_block_hash = None
    for block_id in (1, 2, 3):
        start = (block_id - 1) * block_size
        block_tokens = tuple(all_token_ids[start : start + block_size])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
        blk_hash = manager.block_pool.blocks[block_id].block_hash
        assert blk_hash is not None
        assert get_block_hash(blk_hash) == block_hash
        assert get_group_id(blk_hash) == 0
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash

    # Check partial block metadata
    for block_id in (4,):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # Request #1 is a non-prompt-logprobs request:
    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", req1_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == 3
    assert computed_blocks.get_block_ids() == ([1, 2, 3],)
    assert num_computed_tokens == 3 * block_size
    num_new_tokens = len(req1_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1,
        num_new_tokens,
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([5],)
    for block in computed_blocks.blocks[0]:
        assert block.ref_cnt == 2

    # At this point, we should have 5 free blocks left.
    assert manager.block_pool.free_block_queue.num_free_blocks == 5

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (6, 7, 8, 9, 10)]
    # [unique_req0 (4)]
    # [unique_req1 (5)]
    # [common (3, 2, 1)]
    assert [
        b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]

    # Request #2 is a prompt-logprobs request:
    # NO cache hit in the common prefix; duplicates request #0 cached blocks
    unique_token_ids = [3] * 6
    req2_token_ids = common_token_ids + unique_token_ids
    req2 = make_request(
        "2", req2_token_ids, block_size, hash_fn, prompt_logprobs=5
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(req2.block_hashes) == 3
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    # Allocate exactly the prompt length (54 tokens).
    blocks = manager.allocate_slots(
        req2,
        len(req2_token_ids),
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None
    block_ids = blocks.get_block_ids()
    # Duplicate cached blocks have different ids but same hashes vs request #0
    assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
    assert block_ids != ([1, 2, 3, 4],)

    # Request #2 block hashes are valid since request #0 hashes are.
    # Check block reference counts.
    for block_id in block_ids[0]:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    manager.free(req2)
|
|
|
|
|
|
def test_decode():
    """Test slot allocation during decode.

    Appending tokens that still fit in the request's last block must not
    allocate anything. Outgrowing that block allocates exactly one new block.
    The block that just filled up gets a hash, while the new tail block does
    not.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    def req0_blocks():
        # Current block table of req0 in the (single) KV cache group.
        return manager.coordinator.single_type_managers[0].req_to_blocks[
            req0.request_id
        ]

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(block_size)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    prompt_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", prompt_token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0,
        len(prompt_token_ids),
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],)

    # Append 4 tokens (55 -> 59). They still fit in block 4 (64 slots), so no
    # new block is allocated and the partial tail block has no hash.
    req0.num_computed_tokens = 55
    for _ in range(4):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(
        req0, 4, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
    assert req0_blocks()[-1].block_hash is None

    # Append 19 tokens (59 -> 78). 5 tokens fill the rest of block 4, and the
    # remaining 14 spill into exactly one newly allocated block.
    req0.num_computed_tokens = 59
    for _ in range(5 + 14):
        req0.append_output_token_ids(7)
    new_blocks = manager.allocate_slots(
        req0, 19, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert new_blocks is not None and len(new_blocks.blocks[0]) == 1
    # Block 4 is now full and hashed; the new tail block is still partial.
    assert req0_blocks()[-2].block_hash is not None
    assert req0_blocks()[-1].block_hash is None
|
|
|
|
|
|
def test_evict():
    """
    Check the order of the free-block queue after requests are freed. Also
    check that a later prefix hit pulls the hit blocks back out of that queue.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    def free_count():
        return manager.block_pool.free_block_queue.num_free_blocks

    # req0: 5 full blocks plus 7 tokens in a 6th, partial block.
    req0_len = 5 * block_size + 7
    req0 = make_request("0", list(range(req0_len)), block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(req0)
    assert not hit.blocks[0]
    assert hit_tokens == 0
    allocated = manager.allocate_slots(
        req0, req0_len, len(hit.blocks[0]) * block_size, hit
    )
    assert allocated is not None and len(allocated.blocks[0]) == 6

    # req1: 3 full blocks of tokens continuing where req0's tokens stopped.
    req1_tokens = list(range(req0_len, req0_len + 3 * block_size))
    req1 = make_request("1", req1_tokens, block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(req1)
    assert not hit.blocks[0]
    assert hit_tokens == 0
    allocated = manager.allocate_slots(
        req1, 3 * block_size, len(hit.blocks[0]) * block_size, hit
    )
    assert allocated is not None and len(allocated.blocks[0]) == 3

    # 10 - (6 + 3) == 1
    assert free_count() == 1

    manager.free(req0)
    manager.free(req1)
    assert free_count() == 10
    free_ids = [
        b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ]
    assert free_ids == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]

    # req2 shares req0's first 2 full blocks, which touches them.
    req2 = make_request("2", list(range(2 * block_size + 3)), block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(req2)
    assert hit.get_block_ids() == ([1, 2],)
    assert hit_tokens == 2 * block_size
    allocated = manager.allocate_slots(req2, 3, len(hit.blocks[0]) * block_size, hit)
    assert allocated is not None and allocated.get_block_ids() == ([10],)
    assert free_count() == 7
|
|
|
|
|
|
def test_hash_block_correct_reuse():
    """
    A previously cached block may be reused later as a new block. When that
    happens, its hash metadata must be reset.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 2),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    def prefill(req_id, n_tokens):
        # Allocate a fresh (cache-miss) request that needs exactly one block.
        req = make_request(req_id, list(range(n_tokens)), block_size, sha256)
        hit, hit_tokens = manager.get_computed_blocks(req)
        assert not hit.blocks[0]
        assert hit_tokens == 0
        out = manager.allocate_slots(
            req, n_tokens, len(hit.blocks[0]) * block_size, hit
        )
        assert out is not None and len(out.blocks[0]) == 1
        return req, out

    # Fill exactly one block so it gets cached, then release it.
    first_req, _ = prefill("0", block_size)
    manager.free(first_req)

    # The next allocation holds a not-full block (15 tokens) and must not
    # carry a stale hash.
    _, out = prefill("1", block_size - 1)
    reused_id = out.blocks[0][0].block_id
    assert manager.block_pool.blocks[reused_id].block_hash is None
|
|
|
|
|
|
def test_computed_blocks_not_evicted():
    """
    When a request gets new blocks, its computed (prefix-hit) blocks must not
    be evicted as long as other free blocks are available.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 3),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )
    one_block = block_size * 1

    def allocate_fresh(req_id, token_ids, expected_block_id):
        # Prefill a cache-miss request of one block and check which block it got.
        req = make_request(req_id, token_ids, block_size, sha256)
        hit, hit_tokens = manager.get_computed_blocks(req)
        assert not hit.blocks[0]
        assert hit_tokens == 0
        out = manager.allocate_slots(
            req, one_block, len(hit.blocks[0]) * block_size, hit
        )
        assert out is not None and len(out.blocks[0]) == 1
        assert out.blocks[0][0].block_id == expected_block_id
        return req

    # Cache two independent one-block requests in blocks 1 and 2.
    req0 = allocate_fresh("0", list(range(one_block)), 1)
    req1 = allocate_fresh("1", list(range(one_block, one_block * 2)), 2)

    manager.free(req0)
    manager.free(req1)

    # A cache hit on the first block means the second cached block must be
    # evicted, not the first one.
    req2 = make_request("2", list(range(one_block * 2)), block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(req2)
    assert len(hit.blocks[0]) == 1
    assert hit.blocks[0][0].block_id == 1
    assert hit_tokens == block_size

    out = manager.allocate_slots(
        req2,
        one_block * 2 - one_block,
        len(hit.blocks[0]) * block_size,
        hit,
    )
    assert out is not None and len(out.blocks[0]) == 1
    assert out.blocks[0][0].block_id == 2
|
|
|
|
|
|
def test_basic_prefix_caching_disabled():
    """
    With ``enable_caching=False`` no request ever gets a prefix hit, even when
    its prompt shares a prefix with a previously freed request. Allocation
    fails once no free block is left.
    """
    block_size = 4
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 5),
        max_model_len=8192,
        enable_caching=False,
        hash_block_size=block_size,
    )

    # 2 full blocks and some more.
    req1 = make_request("1", list(range(10)), block_size, sha256)

    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    # Computed-token count is (#computed blocks * block_size); block_size is 4
    # here, not 16.
    blocks = manager.allocate_slots(
        req1, 10, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert blocks is not None and len(blocks.blocks[0]) == 3

    # Free the blocks.
    manager.free(req1)

    # No caching: req2 shares req1's prefix but still gets no hit.
    req2 = make_request("2", list(range(16)), block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req2, 16, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert blocks is not None and len(blocks.blocks[0]) == 4

    # req2 still holds its blocks, so a new request cannot get any blocks and
    # allocate_slots reports the failure with None.
    req3 = make_request("3", list(range(4)), block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req3, 4, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert blocks is None
|
|
|
|
|
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_cache_blocks(hash_fn):
    """
    Unit test for ``BlockPool.cache_full_blocks``.

    Covers caching full blocks from the start of a request and incrementally
    extending an already-cached prefix. After each call, every cached block
    must carry the request's hash for its position.
    """
    block_size = 4
    block_pool = BlockPool(
        num_gpu_blocks=5,
        enable_caching=True,
        hash_block_size=block_size,
    )
    # Req:
    #  Block 0: [0, 1, 2, 3]
    #  Block 1: [4, 5, 6, 7]
    #  Block 2: [8, 9, 10, 11]
    #  Block 3: [12, 13]
    req = make_request("0", list(range(14)), block_size, hash_fn)

    # Test that blocks are cached correctly for 2 full blocks from the start.
    blocks = [KVCacheBlock(block_id=i) for i in range(2)]

    block_pool.cache_full_blocks(
        request=req,
        blocks=blocks,
        num_cached_blocks=0,
        num_full_blocks=2,
        block_size=block_size,
        kv_cache_group_id=0,
    )

    assert len(block_pool.cached_block_hash_to_block) == 2
    assert all(block.block_hash is not None for block in blocks)

    # Test that blocks that don't start from the beginning are cached
    # correctly.
    blocks += [KVCacheBlock(block_id=2)]
    block_pool.cache_full_blocks(
        request=req,
        blocks=blocks,
        num_cached_blocks=2,
        num_full_blocks=3,
        block_size=block_size,
        kv_cache_group_id=0,
    )
    assert len(block_pool.cached_block_hash_to_block) == 3
    # The newly cached third block, not only the earlier ones, must now have
    # a hash that matches the request's hash for its position.
    assert all(block.block_hash is not None for block in blocks)
    for idx, block in enumerate(blocks):
        assert get_block_hash(block.block_hash) == req.block_hashes[idx]
        assert get_group_id(block.block_hash) == 0
|
|
|
|
|
|
def test_cache_blocks_multi_group():
    """
    Blocks cached for different KV cache groups are keyed per group. A lookup
    hits only when every requested group has cached that block hash.
    """
    block_size = 4
    block_pool = BlockPool(
        num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
    )

    # Req:
    #  Block 0/4: [0, 1, 2, 3]
    #  Block 1/5: [4, 5, 6, 7]
    #  Block 2/6: [8, 9, 10, 11]
    #  Block 3/7: [12, 13]
    req = make_request("0", list(range(14)), block_size, sha256)

    # (group id, full blocks to cache, expected total entries in the map)
    caching_plan = ((0, 2, 2), (1, 3, 5))
    for group_id, num_full, expected_entries in caching_plan:
        group_blocks = [KVCacheBlock(block_id=i) for i in range(num_full)]
        block_pool.cache_full_blocks(
            request=req,
            blocks=group_blocks,
            num_cached_blocks=0,
            num_full_blocks=num_full,
            block_size=block_size,
            kv_cache_group_id=group_id,
        )
        assert len(block_pool.cached_block_hash_to_block) == expected_entries
        assert len(req.block_hashes) == 3
        assert all(b.block_hash is not None for b in group_blocks)

    # Block hash 0: hit for group 0 and 1
    # Block hash 1: hit for group 0 and 1
    # Block hash 2: hit for group 1 only
    expected_hits = {
        (0,): (True, True, False),
        (1,): (True, True, True),
        (0, 1): (True, True, False),
    }
    for group_ids, hits in expected_hits.items():
        for block_hash, should_hit in zip(req.block_hashes, hits):
            cached = block_pool.get_cached_block(
                block_hash, kv_cache_group_ids=list(group_ids)
            )
            assert (cached is not None) == should_hit
|
|
|
|
|
|
def test_mm_prefix_caching():
    """
    This tests that the multi-modal prefix caching is correct.

    Each full block's hash mixes in, for every multi-modal item overlapping
    the block, the item's hash paired with the item's offset relative to the
    block start. Requests that repeat the same tokens and the same images at
    the same positions therefore hit the cache.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Common prompt tokens (T is text tokens and P is image placeholder tokens)
    # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
    common_token_ids = list(range(10)) + [-1] * 6
    common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
    common_token_ids += [-1] * 16

    # The placeholder ranges line up with the -1 tokens above. P0 spans
    # indices 10-19 (6 tokens in block 0, 4 in block 1). P1 spans indices
    # 30-47 (2 tokens in block 1, all 16 in block 2).
    common_mm_positions = [
        PlaceholderRange(offset=10, length=10),
        PlaceholderRange(offset=30, length=18),
    ]
    common_mm_hashes = ["aaa", "bbb"]

    # A unique image plus some text tokens.
    unique_token_ids = [-1] * 7 + [100] * 4
    all_token_ids = common_token_ids + unique_token_ids
    mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)]
    mm_hashes = common_mm_hashes + ["ccc"]
    req0 = make_request(
        "0",
        all_token_ids,
        block_size,
        sha256,
        mm_positions=mm_positions,
        mm_hashes=mm_hashes,
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

    # Completed block should have hashes
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    block_hashes = req0.block_hashes
    assert len(block_hashes) == 3
    # Extra keys are (mm_hash, item offset - block start offset).
    assert block_hashes[0] == sha256(
        (
            kv_cache_utils.NONE_HASH,
            tuple(all_token_ids[:block_size]),
            (("aaa", 10),),
        )
    )
    assert block_hashes[1] == sha256(
        (
            block_hashes[0],
            tuple(all_token_ids[block_size : block_size * 2]),
            (("aaa", -6), ("bbb", 14)),
        )
    )
    assert block_hashes[2] == sha256(
        (
            block_hashes[1],
            tuple(all_token_ids[block_size * 2 : block_size * 3]),
            (("bbb", -2),),
        )
    )

    blocks = manager.allocate_slots(
        req0, 59, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert blocks is not None
    assert blocks.get_block_ids() == ([1, 2, 3, 4],)
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block.
    for _ in range(5):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(
        req0, 5, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
    assert len(block_hashes) == 4
    assert block_hashes[3] == sha256(
        (
            block_hashes[2],
            tuple(all_token_ids[3 * block_size :] + [8] * 5),
            (("ccc", 0),),
        )
    )

    # Cache hit.
    unique_token_ids = [-1] * 7 + [200] * 5
    all_token_ids = common_token_ids + unique_token_ids
    mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)]
    mm_hashes = common_mm_hashes + ["ccc"]
    req1 = make_request(
        "1",
        all_token_ids,
        block_size,
        sha256,
        mm_positions=mm_positions,
        mm_hashes=mm_hashes,
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 3
    assert num_computed_tokens == 3 * block_size
|
|
|
|
|
|
def test_cache_key_salting():
    """
    Cache salts are applied during hashing, so the prefix cache is partitioned
    by salt. A request only hits blocks cached under the same salt.

    Only the first block's hash includes the salt; later blocks depend on it
    through the parent-hash chain.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    def salted_chain(token_ids, salt):
        # Expected hashes of the first 3 full blocks of a salted request.
        chain = []
        parent = kv_cache_utils.NONE_HASH
        for blk in range(3):
            tokens = tuple(token_ids[blk * block_size : (blk + 1) * block_size])
            extra = (salt,) if blk == 0 else None
            parent = sha256((parent, tokens, extra))
            chain.append(parent)
        return chain

    # 3 complete blocks and an incomplete block with 11 tokens.
    common_token_ids = [i for i in range(3) for _ in range(block_size)]
    token_ids = common_token_ids + [3] * 11
    req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
    hit, hit_tokens = manager.get_computed_blocks(req0)

    # Nothing is cached yet, but the 3 full blocks already have hashes.
    assert not hit.blocks[0]
    assert hit_tokens == 0
    req0_hashes = req0.block_hashes
    assert len(req0_hashes) == 3
    assert list(req0_hashes) == salted_chain(token_ids, "salt1")

    allocated = manager.allocate_slots(
        req0, 59, len(hit.blocks[0]) * block_size, hit
    )
    assert allocated is not None
    assert allocated.get_block_ids() == ([1, 2, 3, 4],)
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block. The 4th block is now full
    # and its hash carries no salt.
    for _ in range(5):
        req0.append_output_token_ids(8)
    grown = manager.allocate_slots(req0, 5, len(hit.blocks[0]) * block_size, hit)
    assert grown is not None and len(grown.blocks[0]) == 0
    assert len(req0_hashes) == 4
    assert req0_hashes[3] == sha256(
        (req0_hashes[2], tuple(token_ids[3 * block_size :] + [8] * 5), None)
    )

    # Same salt: the 3-block prefix is a cache hit.
    token_ids = common_token_ids + [4] * 11
    req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
    hit, hit_tokens = manager.get_computed_blocks(req1)
    assert len(hit.blocks[0]) == 3
    assert hit_tokens == 3 * block_size

    # Same content, different salt: cache miss and a different hash chain.
    req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
    hit, hit_tokens = manager.get_computed_blocks(req2)
    assert len(hit.blocks[0]) == 0
    assert hit_tokens == 0
    req2_hashes = req2.block_hashes
    assert len(req2_hashes) == 3
    assert list(req2_hashes) == salted_chain(token_ids, "salt2")
|
|
|
|
|
|
def test_prefill_not_enough_free_blocks_with_computed_blocks():
    """
    Test allocate_slots when there are not enough free blocks.

    When a request has computed (prefix-hit) blocks but cannot be allocated
    because too few free blocks remain, the ref counts of those computed
    blocks must be left untouched.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )
    # Complete 3 blocks (48 tokens)
    # | Common-0 | Common-1 | Common-2 | ... |
    common_token_ids = [i for i in range(3) for _ in range(block_size)]
    req0 = make_request("0", common_token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    manager.allocate_slots(
        req0, 48, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[
        req0.request_id
    ]

    # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
    req1 = make_request("1", common_token_ids * 2, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert computed_blocks.blocks[0] == block_part0
    assert num_computed_tokens == 3 * block_size
    manager.allocate_slots(
        req1, 48, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[
        req1.request_id
    ]
    # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
    # | Req1-5(F)| ... |
    manager.free(req1)
    # Req0 still references the common blocks; Req1's own blocks are free.
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    assert {block.ref_cnt for block in block_part1[3:]} == {0}

    # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
    # | Req1-5(F)| Req2-0 | Req2-1 | ... |
    req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    manager.allocate_slots(
        req2,
        block_size * 2,
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )

    # 5 blocks are free, and 3 of them are Req1's freed blocks (Req1-3..5).
    assert manager.block_pool.free_block_queue.num_free_blocks == 5
    # Req3 is Req1 + 3 new blocks, so its first 6 blocks are computed (all of
    # Req1's blocks). Excluding the 3 computed blocks still in the free queue
    # leaves only 2 free blocks, fewer than the 3 new blocks Req3 needs, so
    # Req3 cannot be allocated. In this case, the ref_cnt of the computed
    # blocks should not be changed.
    req3 = make_request("3", common_token_ids * 3, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert computed_blocks.blocks[0] == block_part1
    assert num_computed_tokens == 6 * block_size
    # Req3 cannot be allocated.
    assert (
        manager.allocate_slots(
            req3, 48, len(computed_blocks.blocks[0]) * block_size, computed_blocks
        )
        is None
    )
    # Blocks 0-2 are still used by Req0 only.
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    # Blocks 3-5 are still free.
    assert {block.ref_cnt for block in block_part1[3:]} == {0}
|
|
|
|
|
|
def test_reset_prefix_cache():
    """
    reset_prefix_cache() must refuse while cached blocks are still in use.
    Once every request is freed, it must drop all cached hashes.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    shared_prefix = [i for i in range(3) for _ in range(block_size)]

    req0 = make_request("0", shared_prefix + [3] * 7, block_size, sha256)
    allocated = manager.allocate_slots(req0, 55)
    assert allocated is not None
    assert allocated.get_block_ids() == ([1, 2, 3, 4],)

    req1 = make_request("1", shared_prefix + [4] * 7, block_size, sha256)
    hit, _ = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == 3
    assert len(hit.blocks[0]) == 3
    allocated = manager.allocate_slots(req1, 7, len(hit.blocks[0]) * block_size, hit)
    assert allocated is not None
    assert allocated.get_block_ids() == ([5],)

    # Blocks are still referenced, so the reset is rejected and the cache kept.
    assert not manager.reset_prefix_cache()
    assert manager.block_pool.cached_block_hash_to_block

    for req in (req0, req1):
        manager.free(req)

    assert manager.reset_prefix_cache()
    assert not manager.block_pool.cached_block_hash_to_block
    assert all(blk.block_hash is None for blk in manager.block_pool.blocks)
|
|
|
|
|
|
def test_prefix_cache_stats_disabled():
    """With log_stats=False the manager's prefix_cache_stats stays None."""
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        log_stats=False,  # Disable logging stats
    )
    assert manager.prefix_cache_stats is None

    # Go through every call that consults log_stats.
    req = make_request("0", list(range(block_size)), block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(req)
    assert not hit.blocks[0]
    assert hit_tokens == 0
    manager.allocate_slots(req, block_size, len(hit.blocks[0]) * block_size, hit)
    manager.reset_prefix_cache()

    # None of the calls above may have created a stats object.
    assert manager.prefix_cache_stats is None
|
|
|
|
|
|
def test_maybe_evict_cached_block():
    """``_maybe_evict_cached_block`` removes exactly one block from the map.

    Four pool blocks are inserted into the hash-to-block map by hand. Blocks 0
    and 3 share a hash, so that key maps to a ``{block_id: block}`` dict, while
    the unique hashes map directly to their block. Evicting blocks one at a
    time must drop only the evicted block, and the shared key must survive
    until its last block is evicted.
    """
    pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
    block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
    block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
    block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
    block_hashes = [
        block_hash0,
        block_hash1,
        block_hash2,
        # block3 has exactly the same block_hash as the first block
        block_hash0,
    ]
    assert len(pool.blocks) == len(block_hashes)
    # Manually add all blocks to cached_block_hash_to_block.
    for block, block_hash in zip(pool.blocks, block_hashes):
        block.block_hash = block_hash
        pool.cached_block_hash_to_block.insert(block_hash, block)

    block0, block1, block2, block3 = pool.blocks
    assert pool.cached_block_hash_to_block._cache == {
        block_hash0: {
            block0.block_id: block0,
            block3.block_id: block3,
        },
        block_hash1: block1,
        block_hash2: block2,
    }
    # Evict block1.
    pool._maybe_evict_cached_block(block1)
    assert pool.cached_block_hash_to_block._cache == {
        block_hash0: {block0.block_id: block0, block3.block_id: block3},
        block_hash2: block2,
    }
    # Evict block0: the block_hash0 entry must NOT be removed, because block3
    # also uses the same hash.
    pool._maybe_evict_cached_block(block0)
    assert pool.cached_block_hash_to_block._cache == {
        block_hash0: {block3.block_id: block3},
        block_hash2: block2,
    }
    # Evict block2. Refer to block3 by its id rather than a hard-coded literal,
    # matching the assertions above.
    pool._maybe_evict_cached_block(block2)
    assert pool.cached_block_hash_to_block._cache == {
        block_hash0: {block3.block_id: block3}
    }
    # Evict block3: the last block with block_hash0, so the map becomes empty.
    pool._maybe_evict_cached_block(block3)
    assert pool.cached_block_hash_to_block._cache == {}
|
|
|
|
|
|
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events(blocks_to_cache: int):
    """Check the KV-cache events emitted across allocate, evict and reset.

    1. Allocating a fresh request emits one stored event covering
       ``blocks_to_cache`` hashes, and ``take_events`` drains the event queue.
    2. Freeing it and allocating an identical request emits one removed event
       per previously stored block, followed by a new stored event.
    3. Freeing everything and resetting the prefix cache emits an
       all-blocks-cleared event.
    """
    block_size = 16
    num_tokens = block_size * blocks_to_cache
    # Pool is one block larger than the number of blocks the request fills.
    num_blocks = blocks_to_cache + 1
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, num_blocks),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
        hash_block_size=block_size,
    )
    hash_map = manager.block_pool.cached_block_hash_to_block

    # Phase 1: first allocation -> a single stored event.
    first_req = make_request("0", list(range(num_tokens)), block_size, sha256)
    manager.allocate_slots(first_req, num_tokens)
    stored = manager.take_events()[-1]
    assert len(stored.block_hashes) == blocks_to_cache == len(hash_map)
    assert len(stored.token_ids) == stored.block_size * len(stored.block_hashes)
    assert stored.kv_cache_spec_kind == KVCacheSpecKind.FULL_ATTENTION.value
    assert len(manager.block_pool.kv_event_queue) == 0
    previously_stored = stored.block_hashes

    # Phase 2: free and re-allocate -> blocks_to_cache removed events, then a
    # fresh stored event.
    manager.free(first_req)
    second_req = make_request("1", list(range(num_tokens)), block_size, sha256)
    manager.allocate_slots(second_req, num_tokens)
    events = manager.take_events()

    removed_events = events[:-1]
    for removed in removed_events:
        assert isinstance(removed, BlockRemoved)
        assert removed.block_hashes[0] in previously_stored
    assert len(events) == blocks_to_cache + 1
    assert isinstance(events[-2], BlockRemoved)
    assert len(events[-1].block_hashes) == blocks_to_cache == len(hash_map)

    # Phase 3: free + reset -> a single all-blocks-cleared event.
    manager.free(second_req)
    manager.reset_prefix_cache()
    events = manager.take_events()
    assert isinstance(events[-1], AllBlocksCleared)
    assert len(manager.block_pool.cached_block_hash_to_block) == 0
|
|
|
|
|
|
def test_null_parent_block_hash():
    """BlockStored reports the logical parent hash when the physical parent
    block is the null block.

    The block immediately before the newly cached range is
    ``pool.null_block``, which carries no hash. The emitted ``BlockStored``
    event must still use ``request.block_hashes[num_cached_blocks - 1]`` as
    its parent hash, and caching must not assign a hash to the null block.
    """
    block_size = 1
    num_cached_blocks = 2
    num_full_blocks = 4
    kv_cache_group_id = 0

    pool = BlockPool(
        num_gpu_blocks=8,
        enable_caching=True,
        hash_block_size=block_size,
        enable_kv_cache_events=True,
    )

    req = make_request(
        "req_null_parent",
        prompt_token_ids=[10, 11, 12, 13],
        block_size=block_size,
        hash_fn=sha256,
    )
    assert len(req.block_hashes) == num_full_blocks

    # Physical parent is `null_block` (no hash), while the logical parent hash
    # still exists in `request.block_hashes[num_cached_blocks - 1]`.
    assert pool.null_block.block_hash is None
    new_blocks = pool.get_new_blocks(num_full_blocks - 1)
    # Both slices are unpacked so `blocks` is a flat list of KVCacheBlock
    # objects with the null block at index `num_cached_blocks - 1`. Without
    # the leading `*`, `blocks[0]` would be a nested list, not a block.
    blocks = [
        *new_blocks[: num_cached_blocks - 1],
        pool.null_block,  # physical parent
        *new_blocks[num_cached_blocks - 1 :],
    ]
    assert len(blocks) == num_full_blocks
    assert blocks[num_cached_blocks - 1] is pool.null_block

    pool.cache_full_blocks(
        request=req,
        blocks=blocks,
        num_cached_blocks=num_cached_blocks,
        num_full_blocks=num_full_blocks,
        block_size=block_size,
        kv_cache_group_id=kv_cache_group_id,
    )

    events = pool.take_events()
    assert len(events) == 1
    event = events[0]
    assert isinstance(event, BlockStored)

    expected_parent = kv_cache_utils.maybe_convert_block_hash(
        req.block_hashes[num_cached_blocks - 1]
    )
    assert event.parent_block_hash == expected_parent
    assert event.parent_block_hash is not None

    expected_new_hashes = [
        kv_cache_utils.maybe_convert_block_hash(h)
        for h in req.block_hashes[num_cached_blocks:num_full_blocks]
    ]
    assert event.block_hashes == expected_new_hashes
    assert event.group_idx == kv_cache_group_id
    # A bare BlockPool has no KV-cache spec metadata to annotate events with.
    assert event.kv_cache_spec_kind is None
    assert event.kv_cache_spec_sliding_window is None

    # Ensure we didn't accidentally assign a hash to the null block.
    assert pool.null_block.block_hash is None
    # Sanity check: newly cached physical blocks should have hashes assigned.
    assert blocks[num_cached_blocks].block_hash is not None
    assert blocks[num_full_blocks - 1].block_hash is not None
|
|
|
|
|
|
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events_with_lora(blocks_to_cache: int):
    """Test BlockStored events contain correct lora_id when using LoRA requests."""
    block_size = 16
    num_blocks = blocks_to_cache + 1
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, num_blocks),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
        hash_block_size=block_size,
    )
    adapter = LoRARequest(
        lora_name="test_lora", lora_int_id=42, lora_path="/test/path"
    )
    num_tokens = block_size * blocks_to_cache

    def allocate_and_take_stored(request):
        # Allocate the whole prompt and return the last event, which must be
        # the BlockStored event for the freshly cached blocks.
        manager.allocate_slots(request, num_tokens)
        last_event = manager.take_events()[-1]
        assert isinstance(last_event, BlockStored)
        return last_event

    # With a LoRA adapter: the event carries the adapter's integer id.
    lora_req = make_request(
        "lora_req",
        list(range(num_tokens)),
        block_size,
        sha256,
        lora_request=adapter,
    )
    stored = allocate_and_take_stored(lora_req)
    assert stored.lora_id == 42  # lora_int_id of the adapter above
    assert len(stored.block_hashes) == blocks_to_cache
    assert stored.block_size == block_size
    manager.free(lora_req)

    # Without a LoRA adapter: lora_id is None.
    plain_req = make_request(
        "no_lora_req", list(range(num_tokens)), block_size, sha256
    )
    stored = allocate_and_take_stored(plain_req)
    assert stored.lora_id is None
    assert len(stored.block_hashes) == blocks_to_cache
    assert stored.block_size == block_size
|
|
|
|
|
|
@pytest.mark.parametrize("group_id", [0, 1, 2])
def test_block_stored_event_group_idx(group_id: int):
    """BlockStored events from ``cache_full_blocks`` carry the caching group's
    index together with that group's spec kind and sliding window."""
    block_size = 4
    num_tokens = block_size * 2

    manager = make_kv_cache_manager(
        make_kv_cache_config_three_types(block_size, num_blocks=5),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
        hash_block_size=block_size,
    )
    pool = manager.block_pool
    request = make_request(
        "req_grp_idx",
        prompt_token_ids=list(range(num_tokens)),
        block_size=block_size,
        hash_fn=sha256,
    )

    # Cache two fresh blocks under the parametrized group.
    fresh_blocks = pool.get_new_blocks(2)
    pool.cache_full_blocks(
        request=request,
        blocks=fresh_blocks,
        num_cached_blocks=0,
        num_full_blocks=2,
        block_size=block_size,
        kv_cache_group_id=group_id,
    )

    events = manager.take_events()
    assert len(events) == 1
    stored = events[0]
    assert isinstance(stored, BlockStored)
    assert stored.group_idx == group_id

    # Expected (spec kind, sliding window) per group index, in the order the
    # three-type config lists its groups: full attention, sliding window, mamba.
    expected_kind, expected_window = [
        (KVCacheSpecKind.FULL_ATTENTION.value, None),
        (KVCacheSpecKind.SLIDING_WINDOW.value, 2 * block_size),
        (KVCacheSpecKind.MAMBA.value, None),
    ][group_id]
    assert stored.kv_cache_spec_kind == expected_kind
    assert stored.kv_cache_spec_sliding_window == expected_window
|
|
|
|
|
|
def test_block_stored_event_group_idx_multiple_groups():
    """Each HMA group's BlockStored event carries its own group_idx.

    Full-attention blocks (group 0) and sliding-window blocks (group 1) are
    cached independently. Consumers doing HMA-aware prefix-cache routing must
    be able to tell the two events apart by group index and spec metadata.
    """
    block_size = 4
    num_tokens = block_size * 2
    sliding_window = 128

    full_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        ),
    )
    swa_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=sliding_window,
        ),
    )
    manager = make_kv_cache_manager(
        KVCacheConfig(
            num_blocks=5,
            kv_cache_tensors=[],
            kv_cache_groups=[full_group, swa_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
        hash_block_size=block_size,
    )
    pool = manager.block_pool
    request = make_request(
        "req_multi_grp",
        prompt_token_ids=list(range(num_tokens)),
        block_size=block_size,
        hash_fn=sha256,
    )

    # Cache two fresh blocks for group 0 (full attention), then two more for
    # group 1 (sliding window).
    for group_id in (0, 1):
        group_blocks = pool.get_new_blocks(2)
        pool.cache_full_blocks(
            request=request,
            blocks=group_blocks,
            num_cached_blocks=0,
            num_full_blocks=2,
            block_size=block_size,
            kv_cache_group_id=group_id,
        )

    events = manager.take_events()
    assert len(events) == 2
    expectations = (
        (0, KVCacheSpecKind.FULL_ATTENTION.value, None),
        (1, KVCacheSpecKind.SLIDING_WINDOW.value, sliding_window),
    )
    for event, (group_idx, spec_kind, window) in zip(events, expectations):
        assert isinstance(event, BlockStored)
        assert event.group_idx == group_idx
        assert event.kv_cache_spec_kind == spec_kind
        if window is None:
            assert event.kv_cache_spec_sliding_window is None
        else:
            assert event.kv_cache_spec_sliding_window == window
|
|
|
|
|
|
def test_block_stored_event_group_idx_out_of_bounds(monkeypatch):
    """An event whose group_idx has no KV-cache metadata passes through as-is.

    The manager is built from ``make_kv_cache_config``, and the expected
    warning shows that group index 1 has no metadata in it. ``take_events``
    must therefore return the queued event without a spec-kind or
    sliding-window annotation, and log exactly one warning.
    """
    block_size = 4
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, num_blocks=5),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
        hash_block_size=block_size,
    )
    orphan_event = BlockStored(
        block_hashes=[1],
        parent_block_hash=None,
        token_ids=list(range(block_size)),
        block_size=block_size,
        lora_id=None,
        medium=None,
        lora_name=None,
        group_idx=1,
    )
    manager.block_pool.kv_event_queue.append(orphan_event)

    captured = []

    def fake_warning(message, *args, **_kwargs):
        # Apply logging-style lazy %-formatting so the final text can be checked.
        captured.append(message % args if args else message)

    monkeypatch.setattr(kv_cache_manager.logger, "warning", fake_warning)
    returned = manager.take_events()

    assert returned == [orphan_event]
    assert orphan_event.kv_cache_spec_kind is None
    assert orphan_event.kv_cache_spec_sliding_window is None
    assert captured == ["Group index `1` not in KV cache metadata"]
|
|
|
|
|
|
@pytest.mark.parametrize("group_id", [0, 1, 2])
def test_block_removed_event_group_idx(group_id: int):
    """
    BlockRemoved events emitted on eviction carry the group_idx that
    get_group_id() extracts from the evicted block's BlockHashWithGroupId.
    """
    block_size = 4
    num_tokens = block_size * 2

    # 5 pool blocks = the null block + 4 usable ones. All 4 are allocated and
    # 2 are cached; after freeing, re-allocating all 4 forces both cached
    # blocks through _maybe_evict_cached_block.
    pool = BlockPool(
        num_gpu_blocks=5,
        enable_caching=True,
        hash_block_size=block_size,
        enable_kv_cache_events=True,
    )
    request = make_request(
        "req_evict_grp",
        prompt_token_ids=list(range(num_tokens)),
        block_size=block_size,
        hash_fn=sha256,
    )

    usable = pool.get_new_blocks(4)
    pool.cache_full_blocks(
        request=request,
        blocks=usable,
        num_cached_blocks=0,
        num_full_blocks=2,
        block_size=block_size,
        kv_cache_group_id=group_id,
    )
    # Discard the BlockStored events so only eviction events are seen later.
    pool.take_events()

    # Put everything back on the free queue, then grab it all again; the two
    # hashed blocks are evicted from the cache on the way out.
    pool.free_blocks(usable)
    pool.get_new_blocks(4)

    removed = [ev for ev in pool.take_events() if isinstance(ev, BlockRemoved)]
    assert len(removed) == 2
    assert all(ev.group_idx == group_id for ev in removed)
|
|
|
|
|
|
def test_eagle_enabled_removes_last_block():
    """EAGLE drops a matched block even when the prompt length is an exact
    multiple of the block size.

    A 48-token prompt (3 full blocks) is cached and then looked up again with
    EAGLE enabled. The lookup excludes the last block's hash, and EAGLE then
    drops the last matched block, leaving a single 16-token hit.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
        hash_block_size=block_size,
    )

    # Request with 3 full blocks (48 tokens).
    token_ids = [0] * (3 * block_size)
    req = make_request("divisible_request", token_ids, block_size, sha256)

    # Prime the cache.
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(
        req,
        len(token_ids),
        len(computed_blocks.blocks[0]) * block_size,
        computed_blocks,
    )
    manager.free(req)

    # New request with the same tokens + EAGLE enabled.
    req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

    # Should retain 1 block:
    # 1. Original 3 blocks -> pop last hash -> 2 matched blocks
    # 2. EAGLE drops the last matched block -> 1 remaining block
    assert len(computed_blocks.blocks[0]) == 1
    assert num_tokens == 1 * block_size  # 16 tokens
|
|
|
|
|
|
def test_eagle_with_partial_blocks():
    """EAGLE lookup for a prompt that ends in a partially filled block.

    Only the two full blocks are cacheable; EAGLE then drops the last matched
    one, so the second identical request hits exactly one block.
    """
    block_size = 16
    manager = make_kv_cache_manager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
        hash_block_size=block_size,
    )
    # Two full blocks followed by a 5-token tail (length not a block multiple).
    prompt = [0] * (2 * block_size + 5)

    # Warm the cache with the prompt and release it.
    warm_req = make_request("partial_block_test", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(warm_req)
    manager.allocate_slots(
        warm_req, len(prompt), len(hit.blocks[0]) * block_size, hit
    )
    manager.free(warm_req)

    # Same prompt again: 2 matched full blocks, minus the one EAGLE drops.
    eagle_req = make_request("partial_eagle", prompt, block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(eagle_req)
    assert len(hit.blocks[0]) == 1
    assert hit_tokens == 1 * block_size
|
|
|
|
|
|
def test_eagle_with_sliding_window():
    """EAGLE prefix-cache hits with a single sliding-window group.

    The window is one block wide. A repeated prompt first yields one matched
    block after EAGLE's drop. Once the first block's cache entry is evicted,
    the remaining match disappears entirely after the EAGLE drop.
    """
    block_size = 16
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    manager = make_kv_cache_manager(
        KVCacheConfig(
            num_blocks=10,
            kv_cache_tensors=[],
            kv_cache_groups=[KVCacheGroupSpec(["layer"], swa_spec)],
        ),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
        hash_block_size=block_size,
    )

    # Two full blocks followed by a 5-token tail.
    prompt = [0] * (2 * block_size + 5)

    # Warm the cache.
    warm_req = make_request("partial_block_test", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(warm_req)
    manager.allocate_slots(
        warm_req, len(prompt), len(hit.blocks[0]) * block_size, hit
    )
    # Remember the first block's hash so its cache entry can be evicted below.
    first_hash = warm_req.block_hashes[0]
    assert first_hash is not None
    manager.free(warm_req)

    # Same prompt with EAGLE: 2 matched full blocks, minus 1 dropped by EAGLE.
    eagle_req = make_request("partial_eagle", prompt, block_size, sha256)
    hit, hit_tokens = manager.get_computed_blocks(eagle_req)
    assert len(hit.blocks[0]) == 1
    assert hit_tokens == 1 * block_size

    # Evict the first block by deleting its entry from the hash map directly.
    pool = manager.block_pool
    assert pool.get_cached_block(first_hash, kv_cache_group_ids=[0]) is not None
    pool.cached_block_hash_to_block._cache.pop(
        make_block_hash_with_group_id(first_hash, 0)
    )

    # Ignoring EAGLE, the only hit prefix would be [NULL_BLOCK, BLOCK_2];
    # dropping the last matched block for EAGLE leaves no hit at all.
    evicted_req = make_request(
        "partial_eagle_after_evict", prompt, block_size, sha256
    )
    hit, hit_tokens = manager.get_computed_blocks(evicted_req)
    assert len(hit.blocks[0]) == 0
    assert hit_tokens == 0
|
|
|
|
|
|
def test_eagle_swa_alignment_caches_extra_block():
    """Regression: SWA + EAGLE with ``sliding_window <= alignment_tokens``.

    If the cache-hit alignment (lcm of the per-group block sizes) exceeds the
    SWA window, the SWA cache mask used to keep only the final block of every
    aligned segment. An EAGLE/MTP lookup needs ``tail + 1`` contiguous cached
    blocks, and the extra block sits at the start of the following segment,
    which was never cached. With ``use_eagle=True`` that extra block is now
    cached as well.
    """
    block_size = 8
    # Full attention uses 4 * block_size tokens per block, so the alignment is
    # 4 * block_size = 32 tokens. The SWA window is a single block (tail = 1),
    # so an EAGLE hit needs two adjacent cached SWA blocks. Before the fix the
    # second one was never cached and every EAGLE lookup missed.
    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["full"], full_spec),
            KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 16 SWA-sized blocks (4 aligned segments); hash block k is all token k.
    prompt = [tok // block_size for tok in range(16 * block_size)]

    # Warm the cache with the full prompt, then release it.
    warm_req = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(warm_req)
    allocated = manager.allocate_slots(
        warm_req,
        len(prompt),
        len(hit.blocks[0]) * block_size,
        hit,
    )
    assert allocated is not None
    manager.free(warm_req)

    # An identical prompt must now get a non-empty EAGLE cache hit.
    repeat_req = make_request("1", prompt, block_size, sha256)
    _, hit_tokens = manager.get_computed_blocks(repeat_req)
    assert hit_tokens > 0, (
        "EAGLE + SWA with sliding_window <= alignment failed to find any "
        "cache hit; the +1 block past each segment boundary must be cached."
    )
    # The reported hit is a whole number of 32-token aligned segments.
    assert hit_tokens % (4 * block_size) == 0
|
|
|
|
|
|
def test_eagle_swa_boundary_caches_post_boundary_block():
    """EAGLE + SWA must cache the first block after an alignment boundary.

    With 8-token SWA blocks and a 32-token hybrid alignment, reusing a 32-token
    prefix out of a 40-token computed prefix needs SWA blocks 3 and 4 cached:
    block 3 is the segment tail, and block 4 is the EAGLE lookahead block that
    is dropped after lookup.
    """
    block_size = 8
    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["full"], full_spec),
            KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 5 SWA-sized blocks (40 tokens); hash block k is all token k.
    prompt = [tok // block_size for tok in range(5 * block_size)]
    warm_req = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(warm_req)
    allocated = manager.allocate_slots(
        warm_req,
        len(prompt),
        len(hit.blocks[0]) * block_size,
        hit,
    )
    assert allocated is not None

    # Segment tail (block 3) and EAGLE lookahead (block 4) are both cached in
    # the SWA group.
    pool = manager.block_pool
    for swa_idx in (3, 4):
        assert pool.get_cached_block(
            warm_req.block_hashes[swa_idx], kv_cache_group_ids=[1]
        )
    manager.free(warm_req)

    # Extending the prompt by one token reuses exactly one 32-token segment.
    longer_req = make_request("1", prompt + [999], block_size, sha256)
    _, hit_tokens = manager.get_computed_blocks(longer_req)
    assert hit_tokens == 4 * block_size
|
|
|
|
|
|
def test_eagle_grouped_swa_siblings_use_same_cache_mask():
    """Grouped SWA siblings must cache the EAGLE lookahead block together.

    Groups 1 and 2 share one SWA spec (group 2 is marked as the EAGLE group).
    The lookahead blocks at hash indices 4 and 8 must be found when looking
    up both groups at once, so a follow-up request can reuse 8 blocks.
    """
    block_size = 8
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["full"], full_spec),
            KVCacheGroupSpec(["swa_main"], swa_spec),
            KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 9 SWA-sized blocks (72 tokens); hash block k is all token k.
    prompt = [tok // block_size for tok in range(9 * block_size)]
    warm_req = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(warm_req)
    allocated = manager.allocate_slots(
        warm_req,
        len(prompt),
        len(hit.blocks[0]) * block_size,
        hit,
    )
    assert allocated is not None

    pool = manager.block_pool
    for swa_idx in (4, 8):
        assert pool.get_cached_block(
            warm_req.block_hashes[swa_idx], kv_cache_group_ids=[1, 2]
        )
    manager.free(warm_req)

    longer_req = make_request("1", prompt + [999], block_size, sha256)
    _, hit_tokens = manager.get_computed_blocks(longer_req)
    assert hit_tokens == 8 * block_size
|
|
|
|
|
|
def test_different_block_size():
    """Hybrid full-attention + sliding-window groups with different block sizes.

    Full attention uses 32-token blocks and sliding window uses 16-token
    blocks, so cache hits must be a multiple of 32 tokens. The full-attention
    and sliding-window block counts in each hit must cover the same number
    of tokens.
    """
    block_size = 16
    # Full attention and sliding window attention layers have the same page
    # size: (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(
                    block_size=block_size * 2,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(
                ["layer2"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=2 * block_size,
                ),
            ),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # 10 hash blocks of 16 tokens each; hash block i is filled with token id i.
    common_token_ids = [i for i in range(10) for _ in range(block_size)]

    req0 = make_request("0", common_token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks[0]
    assert not computed_blocks.blocks[1]
    assert num_computed_tokens == 0
    # Pass the token count reported by the lookup. Multiplying the number of
    # group-0 blocks by 16 would under-count, because group 0 uses 32-token
    # blocks.
    blocks = manager.allocate_slots(
        req0, 7 * block_size, num_computed_tokens, computed_blocks
    )
    assert blocks is not None
    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])

    req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 3
    assert len(computed_blocks.blocks[1]) == 6
    assert num_computed_tokens == 6 * block_size

    req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks.blocks[0]) == 3
    assert len(computed_blocks.blocks[1]) == 6
    assert num_computed_tokens == 6 * block_size

    # Evict SWA blocks 11 and 10 (hashes 6 and 5) so the sliding-window hit
    # length becomes 5 * 16. The result must still be 4 * 16, because the
    # full-attention hit length must be a multiple of 32.
    manager.block_pool.cached_block_hash_to_block.pop(
        make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
    )
    manager.block_pool.cached_block_hash_to_block.pop(
        make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 2
    assert len(computed_blocks.blocks[1]) == 4
    assert num_computed_tokens == 4 * block_size
|
|
|
|
|
|
def test_hybrid_cache_blocks_swa_tail_window_only():
    """Only the SWA blocks that can serve a hit are cached.

    Within each lcm-aligned segment, SWA's ``find_longest_cache_hit`` returns
    only the trailing ``ceil((sliding_window - 1) / block_size)`` blocks (its
    right-to-left scan stops at the first contiguous match). Earlier blocks in
    the segment can never serve a hit, so
    ``HybridKVCacheCoordinator.cache_blocks`` should skip them instead of
    polluting the prefix-cache hash map.
    """
    block_size = 8
    # Full attn block_size=32, SWA block_size=8, sw=8 -> lcm=32.
    # tail = ceil(7/8) = 1; blocks per segment = 32/8 = 4.
    # Per-segment template = [F, F, F, T]: only the last SWA block in each
    # 32-token segment lands in the prefix-cache hash map.
    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["layer1"], full_spec),
            KVCacheGroupSpec(["layer2"], swa_spec),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # 8 hash blocks of 8 tokens (64 tokens = two lcm-aligned segments).
    prompt = [tok // block_size for tok in range(8 * block_size)]
    req = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(req)
    allocated = manager.allocate_slots(
        req,
        8 * block_size,
        len(hit.blocks[0]) * block_size,
        hit,
    )
    assert allocated is not None
    assert len(req.block_hashes) == 8

    pool = manager.block_pool
    # SWA group 1: only the segment tails (hashes 3 and 7) may be cached;
    # hashes 0, 1, 2, 4, 5, 6 cannot serve any lcm-aligned hit.
    segment_tails = {3, 7}
    for idx, block_hash in enumerate(req.block_hashes):
        cached = pool.get_cached_block(block_hash, kv_cache_group_ids=[1])
        if idx not in segment_tails:
            assert cached is None, (
                f"SWA hash {idx} cannot serve any lcm-aligned hit; should not be cached"
            )
            continue
        assert cached is not None, f"SWA hash {idx} should be cached"
|
|
|
|
|
|
def test_hybrid_cache_blocks_clamped_to_lcm():
    """HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size.

    Blocks past the last lcm-aligned boundary can never take part in a cache
    hit (find_longest_cache_hit always returns lcm-aligned hits). Caching them
    would only pollute the prefix-cache hash map and keep blocks on the LRU
    list that could otherwise go back to the free pool.
    """
    block_size = 16
    # Full attn block_size=32, SWA block_size=16 -> lcm=32.
    full_spec = FullAttentionSpec(
        block_size=block_size * 2,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=2 * block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["layer1"], full_spec),
            KVCacheGroupSpec(["layer2"], swa_spec),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # 7 hash blocks of 16 tokens (112 tokens). With lcm=32 the clamp truncates
    # to 96 tokens: SWA caches 6 hashes, full attention caches 3.
    prompt = [tok // block_size for tok in range(7 * block_size)]
    req = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(req)
    allocated = manager.allocate_slots(
        req,
        7 * block_size,
        len(hit.blocks[0]) * block_size,
        hit,
    )
    assert allocated is not None
    assert len(req.block_hashes) == 7

    pool = manager.block_pool
    # SWA group 1: hashes 0..5 (96 tokens) are cached; hash 6 covers tokens
    # [96, 112), past the lcm boundary, and must not be.
    *aligned_hashes, trailing_hash = req.block_hashes
    for idx, block_hash in enumerate(aligned_hashes):
        assert (
            pool.get_cached_block(block_hash, kv_cache_group_ids=[1]) is not None
        ), f"SWA hash {idx} should be cached"
    assert pool.get_cached_block(trailing_hash, kv_cache_group_ids=[1]) is None, (
        "SWA hash 6 spans tokens past the lcm boundary; should not be cached"
    )
|
|
|
|
|
|
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
    """Verify fixed intervals retain sparse tails plus the latest replay tail.

    Hybrid layout: a full-attention group with 32-token blocks and a
    sliding-window group with 8-token blocks and an 8-token window. With a
    64-token retention interval, only the SWA blocks that end on a retention
    boundary or on the latest replay boundary may stay in the prefix cache.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
    block_size = 8

    full_attn_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    swa_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[full_attn_group, swa_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # The 64-token interval is a multiple of the 32-token lcm_block_size, so
    # the SWA manager uses it as its retention segment. For this 128-token
    # prompt the retained SWA tails end at 64 tokens (interval), 96 tokens
    # (latest replay boundary) and 128 tokens (interval).
    num_hash_blocks = 16
    prompt = [tok for tok in range(num_hash_blocks) for _ in range(block_size)]
    request = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(request)
    num_hit_tokens = len(hit.blocks[0]) * block_size
    allocated = manager.allocate_slots(request, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    block_pool = manager.block_pool
    expected_swa_cached = {7, 11, 15}
    for idx in range(num_hash_blocks):
        lookup = block_pool.get_cached_block(
            request.block_hashes[idx], kv_cache_group_ids=[1]
        )
        if idx not in expected_swa_cached:
            assert lookup is None, f"SWA hash {idx} should not be cached"
        else:
            assert lookup is not None, f"SWA hash {idx} should be cached"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "interval, expected_match",
    [
        # scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
        ("33", "multiple of scheduler_block_size"),
        # A negative multiple (-32 % 32 == 0) must still be rejected explicitly,
        # otherwise it would pass the modulo check and silently degrade to dense.
        ("-32", "non-negative"),
    ],
)
def test_hybrid_local_kv_retention_interval_rejects_invalid(
    monkeypatch, interval, expected_match
):
    """Reject retention intervals that are negative or misaligned.

    Building the manager must raise ValueError (whose message matches
    ``expected_match``) when the interval is negative or is not a multiple of
    scheduler_block_size.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
    block_size = 8

    groups = [
        KVCacheGroupSpec(
            ["layer1"],
            FullAttentionSpec(
                block_size=4 * block_size,
                num_kv_heads=1,
                head_size=1,
                dtype=torch.float16,
            ),
        ),
        KVCacheGroupSpec(
            ["layer2"],
            SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=1,
                head_size=1,
                dtype=torch.float32,
                sliding_window=block_size,
            ),
        ),
    ]
    config = KVCacheConfig(num_blocks=100, kv_cache_tensors=[], kv_cache_groups=groups)

    with pytest.raises(ValueError, match=expected_match):
        make_kv_cache_manager(
            kv_cache_config=config,
            max_model_len=8192,
            enable_caching=True,
            hash_block_size=block_size,
        )
|
|
|
|
|
|
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
    """Verify retained local checkpoints are reused after block recycling.

    One compressed MLA full-attention group and three sliding-window groups
    with different block sizes share an 800-block pool. After a first request
    is prefilled and freed, a replay must hit 1024 tokens in every group, and
    it must still do so after a second, unrelated request has been prefilled
    and freed.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
    hash_block_size = 4

    def local_group(layer_name: str, size_in_hash_blocks: int, window: int):
        """Build a sliding-window group whose block spans N hash blocks."""
        return KVCacheGroupSpec(
            [layer_name],
            SlidingWindowSpec(
                block_size=size_in_hash_blocks * hash_block_size,
                num_kv_heads=1,
                head_size=1,
                dtype=torch.float32,
                sliding_window=window,
            ),
        )

    mla_group = KVCacheGroupSpec(
        ["full"],
        MLAAttentionSpec(
            block_size=64 * hash_block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.uint8,
            compress_ratio=4,
        ),
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=800,
            kv_cache_tensors=[],
            kv_cache_groups=[
                mla_group,
                local_group("swa", 16, 512),
                local_group("c128", 2, 128),
                local_group("c4", 1, 8),
            ],
        ),
        max_model_len=4096,
        enable_caching=True,
        hash_block_size=hash_block_size,
    )

    def fill_request(request_id: str, token_offset: int) -> list[int]:
        """Prefill a 4096-token request in 512-token chunks, then free it.

        Returns the prompt token ids so they can be replayed later.
        """
        token_ids = [
            token_offset + i for i in range(1024) for _ in range(hash_block_size)
        ]
        request = make_request(request_id, token_ids, hash_block_size, sha256)
        while (remaining := len(token_ids) - request.num_computed_tokens) > 0:
            step = min(512, remaining)
            allocated = manager.allocate_slots(request, step)
            assert allocated is not None
            request.num_computed_tokens += step
        manager.free(request)
        return token_ids

    def check_replay(request_id: str, prompt_ids: list[int]) -> None:
        """Replay 1800 tokens and expect a 1024-token hit in every group."""
        replay = make_request(request_id, prompt_ids[:1800], hash_block_size, sha256)
        hit, num_hit_tokens = manager.get_computed_blocks(replay)
        assert num_hit_tokens == 1024
        # 1024 tokens over block sizes 256 / 64 / 8 / 4.
        assert [len(group) for group in hit.blocks] == [4, 16, 128, 256]

    prompt_ids = fill_request("fill_0", 0)
    check_replay("replay", prompt_ids)

    # A second request with disjoint token ids is prefilled and freed before
    # replaying the first prompt again.
    fill_request("fill_1", 100_000)
    check_replay("replay_again", prompt_ids)
|
|
|
|
|
|
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
    """Verify latest-only retention reuses only the replayable prompt boundary.

    With the retention interval set to 0, the only SWA block left in the
    prefix cache after a 128-token prompt is the tail ending at the 96-token
    replay boundary. A replay of the same prompt hits 96 tokens. A 96-token
    prompt cannot use that boundary (its final block is recomputed), and no
    earlier SWA tail was kept, so it gets no hit at all.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 8

    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[
                KVCacheGroupSpec(["layer1"], full_spec),
                KVCacheGroupSpec(["layer2"], swa_spec),
            ],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    prompt = [tok for tok in range(16) for _ in range(block_size)]
    first = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(first)
    num_hit_tokens = len(hit.blocks[0]) * block_size
    allocated = manager.allocate_slots(first, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    block_pool = manager.block_pool
    expected_swa_cached = {11}
    for idx in range(16):
        lookup = block_pool.get_cached_block(
            first.block_hashes[idx], kv_cache_group_ids=[1]
        )
        if idx not in expected_swa_cached:
            assert lookup is None, f"SWA hash {idx} should not be cached"
        else:
            assert lookup is not None, f"SWA hash {idx} should be cached"

    # After the request is freed, the retained tail stays cached but unowned.
    manager.free(first)
    retained = block_pool.get_cached_block(first.block_hashes[11], [1])
    assert retained is not None
    assert retained[0].ref_cnt == 0

    second = make_request("1", prompt, block_size, sha256)
    hit, num_hit_tokens = manager.get_computed_blocks(second)
    # Full prompt hits intentionally recompute the final block for logits, so
    # the longest usable hit is the previous LCM boundary: 96 tokens.
    assert num_hit_tokens == 12 * block_size
    assert len(hit.blocks[1]) == 12

    shorter = make_request("2", prompt[: 12 * block_size], block_size, sha256)
    hit, num_hit_tokens = manager.get_computed_blocks(shorter)
    assert num_hit_tokens == 0
    assert len(hit.blocks[1]) == 0
|
|
|
|
|
|
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
    """Verify MTP/EAGLE SWA retention keeps the extra proof block.

    An EAGLE/MTP lookup matches one additional local block past the prefix it
    returns, then drops it. Sparse retention therefore has to cache the usual
    local tail at the latest replay boundary plus one more SWA block.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 8

    full_group = KVCacheGroupSpec(
        ["full"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    eagle_swa_group = KVCacheGroupSpec(
        ["swa_mtp"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
        is_eagle_group=True,
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[full_group, eagle_swa_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 127 tokens: the latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
    # The EAGLE/MTP SWA lookup group must cache the local tail ending at
    # 104 tokens, which is two 8-token blocks wide: hashes 11 and 12.
    prompt = [tok for tok in range(15) for _ in range(block_size)] + [15] * 7
    first = make_request("0", prompt, block_size, sha256)
    hit, num_hit_tokens = manager.get_computed_blocks(first)
    assert num_hit_tokens == 0
    allocated = manager.allocate_slots(first, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    block_pool = manager.block_pool
    expected_swa_cached = {11, 12}
    for idx in range(15):
        lookup = block_pool.get_cached_block(
            first.block_hashes[idx], kv_cache_group_ids=[1]
        )
        if idx not in expected_swa_cached:
            assert lookup is None, f"SWA hash {idx} should not be cached"
        else:
            assert lookup is not None, f"SWA hash {idx} should be cached"

    manager.free(first)

    second = make_request("1", prompt, block_size, sha256)
    hit, num_hit_tokens = manager.get_computed_blocks(second)
    assert num_hit_tokens == 12 * block_size
    assert [len(group) for group in hit.blocks] == [3, 12]
|
|
|
|
|
|
def test_block_lookup_cache_single_block_per_key():
    """Exercise BlockHashToBlockMap when each key holds at most one block.

    ``pop`` only removes an entry when both the key and the block id match;
    a mismatched id returns None and leaves the map unchanged.
    """
    cache = BlockHashToBlockMap()
    key0, key1, key2 = (
        BlockHashWithGroupId(raw) for raw in (b"hash0", b"hash1", b"hash2")
    )
    block0, block1 = KVCacheBlock(0), KVCacheBlock(1)

    def check(expected0, expected1, expected2):
        """Assert which block each of key0, key1 and key2 currently returns."""
        for key, expected in ((key0, expected0), (key1, expected1), (key2, expected2)):
            assert cache.get_one_block(key) is expected

    check(None, None, None)

    cache.insert(key0, block0)
    check(block0, None, None)

    cache.insert(key1, block1)
    check(block0, block1, None)

    # Block id mismatch: nothing is popped.
    assert cache.pop(key0, 100) is None
    check(block0, block1, None)

    # Matching (key0, block id 0) removes block0.
    assert cache.pop(key0, 0) is block0
    check(None, block1, None)

    # Block id mismatch: nothing is popped.
    assert cache.pop(key0, 1) is None
    check(None, block1, None)

    # Matching (key1, block id 1) removes block1.
    assert cache.pop(key1, 1) is block1
    check(None, None, None)
|
|
|
|
|
|
def test_block_lookup_cache_multi_blocks_per_key():
    """Exercise BlockHashToBlockMap when several blocks share one key.

    Blocks inserted under the same key come back from ``get_one_block`` in
    insertion order as each one is popped by its block id. Popping an id that
    is no longer present returns None.
    """
    cache = BlockHashToBlockMap()
    key0 = BlockHashWithGroupId(b"hash0")
    key1 = BlockHashWithGroupId(b"hash1")
    chain0 = [KVCacheBlock(block_id) for block_id in (0, 1)]
    chain1 = [KVCacheBlock(block_id) for block_id in (10, 11)]

    for key in (key0, key1):
        assert cache.get_one_block(key) is None

    for key, chain in ((key0, chain0), (key1, chain1)):
        for block in chain:
            cache.insert(key, block)

    for key, chain, absent_id in ((key0, chain0, 2), (key1, chain1, 12)):
        for block in chain:
            assert cache.get_one_block(key) is block
            assert cache.pop(key, block.block_id) is block
        assert cache.get_one_block(key) is None
        assert cache.pop(key, absent_id) is None
|
|
|
|
|
|
def test_can_fit_full_sequence_swa_cap_admits_long_prompt():
    """Hybrid full+SWA model with a pool sized at the startup minimum should
    admit a prompt longer than the SWA cap, because SlidingWindowManager
    recycles blocks during chunked prefill (issue #39734).

    The prompt length is chosen so the request is admitted only when the SWA
    demand is capped. Charging the SWA group for every prompt block would
    overflow the pool, while charging it the capped amount fits.
    """
    block_size = 16
    sliding_window = 4 * block_size  # 64 tokens
    max_num_batched_tokens = 8 * block_size  # 128 tokens
    max_model_len = 64 * block_size  # 1024 tokens — much larger than the SWA cap
    # Startup pool sizing: full demands cdiv(max_model_len, bs) = 64 blocks,
    # SWA demands cdiv(SW-1+max_batched, bs) + 1 = cdiv(191, 16) + 1 = 13.
    # Pool minimum = 64 + 13 = 77; +1 for the null block.
    swa_cap_blocks = 13
    num_blocks = 64 + swa_cap_blocks + 1
    usable_blocks = num_blocks - 1  # excludes the null block

    config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer_full"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
            KVCacheGroupSpec(
                ["layer_swa"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=sliding_window,
                ),
            ),
        ],
    )

    manager = make_kv_cache_manager(
        config,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # A prompt shorter than max_model_len but much longer than SW + chunk:
    # 48 blocks (768 tokens). Without the cap, admission would demand at least
    # 48 (full) + 48 (SWA) = 96 blocks > 77 and the request would be rejected.
    # With the cap, SWA contributes only 13, so total = 48 + 13 = 61 <= 77.
    # That leaves room for any decode tokens the admission check may also
    # count; even a full max_model_len charge (64 + 13 = 77) still fits.
    # (A 32-block prompt would need only 64 blocks uncapped and would be
    # admitted even without the cap, which would make this test vacuous.)
    prompt_blocks = 48
    prompt_len = prompt_blocks * block_size
    # Guard the premise so future edits cannot silently make the test vacuous.
    assert 2 * prompt_blocks > usable_blocks
    assert prompt_blocks + swa_cap_blocks <= usable_blocks
    req = make_request("long", list(range(prompt_len)), block_size, sha256)

    assert (
        manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is not None
    )
|
|
|
|
|
|
def test_can_fit_full_sequence_full_attention_still_gates_oversized():
    """The SWA cap must not loosen the full-attention admission check.

    Only the sliding-window group's demand is capped. A prompt whose
    full-attention demand alone exceeds the pool must still be rejected.
    """
    block_size = 16

    full_attn_group = KVCacheGroupSpec(
        ["layer_full"],
        FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        ),
    )
    swa_group = KVCacheGroupSpec(
        ["layer_swa"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=4 * block_size,
        ),
    )
    manager = make_kv_cache_manager(
        KVCacheConfig(
            num_blocks=5,  # deliberately tiny pool
            kv_cache_tensors=[],
            kv_cache_groups=[full_attn_group, swa_group],
        ),
        max_model_len=64 * block_size,
        max_num_batched_tokens=8 * block_size,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # 16 blocks of full attention demand alone exceeds the 5-block pool.
    prompt_len = 16 * block_size
    oversized = make_request("oversized", list(range(prompt_len)), block_size, sha256)

    outcome = manager.allocate_slots(oversized, block_size, full_sequence_must_fit=True)
    assert outcome is None
|
|
|
|
|
|
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
    """Default path (no retention): freeing an SWA request must place its
    uncached scratch blocks at the front of the free queue, so they are
    recycled first. Its cached checkpoint blocks must go to the back, so they
    stay available for prefix hits. This split is always on, independent of
    the retention interval.
    """
    monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
    block_size = 8

    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[
                KVCacheGroupSpec(["layer1"], full_spec),
                KVCacheGroupSpec(["layer2"], swa_spec),
            ],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    prompt = [tok for tok in range(16) for _ in range(block_size)]
    req = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(req)
    num_hit_tokens = len(hit.blocks[0]) * block_size
    allocated = manager.allocate_slots(req, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    # Split the request's non-null SWA blocks into cached (hashed) and
    # scratch (unhashed) sets before the request is freed.
    swa_manager = manager.coordinator.single_type_managers[1]
    null_block = manager.block_pool.null_block
    live = [
        (idx, blk)
        for idx, blk in enumerate(swa_manager.req_to_blocks[req.request_id])
        if blk is not null_block
    ]
    uncached_ids: set[int] = {blk.block_id for _, blk in live if blk.block_hash is None}
    retained = [(idx, blk) for idx, blk in live if blk.block_hash is not None]
    cached_ids: set[int] = {blk.block_id for _, blk in retained}
    cached_hash_indices: list[int] = [idx for idx, _ in retained]
    # The dense default mask caches only the per-segment tails, so a 16-block
    # SWA prompt must produce a mix of retained and scratch blocks.
    assert cached_ids, "expected some retained (cached) SWA tail blocks"
    assert uncached_ids, "expected some scratch (uncached) SWA blocks"

    manager.free(req)

    free_queue = manager.block_pool.free_block_queue
    position = {
        blk.block_id: rank for rank, blk in enumerate(free_queue.get_all_free_blocks())
    }
    # Every scratch block is recycled before every retained block.
    last_scratch = max(position[bid] for bid in uncached_ids)
    first_retained = min(position[bid] for bid in cached_ids)
    assert last_scratch < first_retained

    # The retained tails survive the free and still serve a prefix-cache hit.
    for idx in cached_hash_indices:
        lookup = manager.block_pool.get_cached_block(
            req.block_hashes[idx], kv_cache_group_ids=[1]
        )
        assert lookup is not None
|
|
|
|
|
|
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
    """Single sliding-window group (UnitaryKVCacheCoordinator).

    Builds a caching-enabled manager with a single SWA group whose block size
    is also the hash block size. Extra keyword arguments are forwarded to
    ``make_kv_cache_manager``.
    """
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=sliding_window,
    )
    config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(["layer"], swa_spec)],
    )
    return make_kv_cache_manager(
        kv_cache_config=config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        **kwargs,
    )
|
|
|
|
|
|
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
    """Sparse retention must also work for a pure-SWA single-group model.

    Only the per-interval tails plus the latest replay tail should be cached,
    and a replay of the prompt should still hit the latest replayable
    boundary.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
    block_size = 16
    num_prompt_blocks = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
    assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"

    prompt = [tok for tok in range(num_prompt_blocks) for _ in range(block_size)]
    first = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(first)
    num_hit_tokens = len(hit.blocks[0]) * block_size
    allocated = manager.allocate_slots(first, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    block_pool = manager.block_pool
    cached = set()
    for idx in range(num_prompt_blocks):
        lookup = block_pool.get_cached_block(
            first.block_hashes[idx], kv_cache_group_ids=[0]
        )
        if lookup is not None:
            cached.add(idx)
    # per_segment = 64 / 16 = 4 and need = cdiv(16-1, 16) = 1, so the segment
    # tails are the blocks with i % 4 == 3: {3, 7, 11, 15}. The latest replay
    # boundary (255 // 16 * 16 = 240) adds tail block 14. This is a strict
    # subset of all 16 blocks, i.e. retention is really sparse for pure SWA.
    assert cached == {3, 7, 11, 14, 15}

    # Replaying the same prompt hits the latest replayable boundary (240).
    replay = make_request("1", prompt, block_size, sha256)
    _, replay_hit_tokens = manager.get_computed_blocks(replay)
    assert replay_hit_tokens == 240
|
|
|
|
|
|
def test_pure_swa_retention_latest_only(monkeypatch):
    """A retention interval of 0 on a pure-SWA model keeps only the latest
    replay tail, which a replay of the same prompt still hits."""
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 16
    num_prompt_blocks = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)

    prompt = [tok for tok in range(num_prompt_blocks) for _ in range(block_size)]
    first = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(first)
    num_hit_tokens = len(hit.blocks[0]) * block_size
    allocated = manager.allocate_slots(first, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    block_pool = manager.block_pool
    cached = set()
    for idx in range(num_prompt_blocks):
        lookup = block_pool.get_cached_block(
            first.block_hashes[idx], kv_cache_group_ids=[0]
        )
        if lookup is not None:
            cached.add(idx)
    # No segment tails (interval 0); only the latest replay tail (block 14).
    assert cached == {14}

    replay = make_request("1", prompt, block_size, sha256)
    _, replay_hit_tokens = manager.get_computed_blocks(replay)
    assert replay_hit_tokens == 240
|
|
|
|
|
|
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
    """With retention unset, a pure-SWA model keeps the dense behavior.

    Every block boundary is a potential prefix hit, so every prompt block
    must end up in the prefix cache.
    """
    monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
    block_size = 16
    num_prompt_blocks = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)

    prompt = [tok for tok in range(num_prompt_blocks) for _ in range(block_size)]
    first = make_request("0", prompt, block_size, sha256)
    hit, _ = manager.get_computed_blocks(first)
    num_hit_tokens = len(hit.blocks[0]) * block_size
    allocated = manager.allocate_slots(first, len(prompt), num_hit_tokens, hit)
    assert allocated is not None

    block_pool = manager.block_pool
    cached = set()
    for idx in range(num_prompt_blocks):
        lookup = block_pool.get_cached_block(
            first.block_hashes[idx], kv_cache_group_ids=[0]
        )
        if lookup is not None:
            cached.add(idx)
    assert cached == set(range(16))
|