[Core][Refactor]: thread scheduler_block_size into KVCacheManager and KVCacheCoordinator (#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
This commit is contained in:
Yifan Qiao
2026-06-02 01:14:44 -07:00
committed by GitHub
parent b817b23f7b
commit 7c37096620
9 changed files with 99 additions and 50 deletions
+12 -3
View File
@@ -1447,7 +1447,10 @@ def test_allocate_with_lookahead():
# Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
kv_cache_config=config,
max_model_len=100,
scheduler_block_size=block_size,
hash_block_size=block_size,
)
blocks = kv_cache_manager.allocate_slots(
request,
@@ -1458,7 +1461,10 @@ def test_allocate_with_lookahead():
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
kv_cache_config=config,
max_model_len=100,
scheduler_block_size=block_size,
hash_block_size=block_size,
)
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
@@ -1471,7 +1477,10 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
kv_cache_config=config,
max_model_len=100,
scheduler_block_size=block_size,
hash_block_size=block_size,
)
blocks = kv_cache_manager.allocate_slots(
request,
+43 -30
View File
@@ -4,6 +4,7 @@
import copy
from collections.abc import Callable
from math import lcm
import pytest
import torch
@@ -92,6 +93,18 @@ def make_request(
)
def make_kv_cache_manager(kv_cache_config: KVCacheConfig, **kwargs) -> KVCacheManager:
"""Build a ``KVCacheManager``, deriving ``scheduler_block_size`` from the
config (LCM of group block sizes) unless explicitly provided. This mirrors
``resolve_kv_cache_block_sizes`` for the non-context-parallel case used by
these tests, so callers don't have to pass it at every site."""
kwargs.setdefault(
"scheduler_block_size",
lcm(*(g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups)),
)
return KVCacheManager(kv_cache_config, **kwargs)
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
return KVCacheConfig(
num_blocks=num_blocks,
@@ -208,7 +221,7 @@ def make_kv_cache_config_three_types(
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_prefill(hash_fn):
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -331,7 +344,7 @@ def test_prefill(hash_fn):
def test_prefill_hybrid_model():
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config_hybrid_model(block_size, 21, 2),
max_model_len=8192,
enable_caching=True,
@@ -500,7 +513,7 @@ def test_prefill_hybrid_model():
def test_prefill_hybrid_model_eagle():
block_size = 16
kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -837,7 +850,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
num_blocks = 10 * num_groups
kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -912,7 +925,7 @@ def test_prefill_hybrid_model_combinations_eagle(
num_blocks = 10 * num_groups
kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -984,7 +997,7 @@ def test_prefill_hybrid_model_mamba_align():
kv_cache_config = _make_hybrid_kv_cache_config(
block_size, num_blocks, ["full", "mamba_align"]
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -1017,7 +1030,7 @@ def test_prefill_plp():
3. Schedule plp request; no hit should occur; validate blocks
"""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1125,7 +1138,7 @@ def test_prefill_plp():
def test_decode():
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1188,7 +1201,7 @@ def test_decode():
def test_evict():
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1247,7 +1260,7 @@ def test_hash_block_correct_reuse():
its hash metadata should be correctly reset.
"""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
@@ -1288,7 +1301,7 @@ def test_computed_blocks_not_evicted():
for a request if there are any other free blocks.
"""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
@@ -1347,7 +1360,7 @@ def test_basic_prefix_caching_disabled():
This tests that the prefix caching is disabled.
"""
block_size = 4
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
@@ -1531,7 +1544,7 @@ def test_mm_prefix_caching():
"""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1639,7 +1652,7 @@ def test_cache_key_salting():
is separated cache as expected.
"""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1721,7 +1734,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
the computed blocks should not be touched.
"""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1794,7 +1807,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def test_reset_prefix_cache():
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1835,7 +1848,7 @@ def test_reset_prefix_cache():
def test_prefix_cache_stats_disabled():
"""Test that prefix_cache_stats is None when log_stats is False."""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
@@ -1915,7 +1928,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see a single block stored event with a blocks_to_cache number of
# block hashes
# take_events should reset the kv_event_queue
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, num_blocks),
max_model_len=8192,
enable_caching=True,
@@ -2043,7 +2056,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
num_blocks = blocks_to_cache + 1
# Create KVCacheManager with events enabled
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, num_blocks),
max_model_len=8192,
enable_caching=True,
@@ -2101,7 +2114,7 @@ def test_block_stored_event_group_idx(group_id: int):
block_size = 4
num_tokens = block_size * 2
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config_three_types(block_size, num_blocks=5),
max_model_len=8192,
enable_caching=True,
@@ -2161,7 +2174,7 @@ def test_block_stored_event_group_idx_multiple_groups():
block_size = 4
num_tokens = block_size * 2
manager = KVCacheManager(
manager = make_kv_cache_manager(
KVCacheConfig(
num_blocks=5,
kv_cache_tensors=[],
@@ -2238,7 +2251,7 @@ def test_block_stored_event_group_idx_multiple_groups():
def test_block_stored_event_group_idx_out_of_bounds(monkeypatch):
"""Out-of-range group_idx events are returned without metadata annotation."""
block_size = 4
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, num_blocks=5),
max_model_len=8192,
enable_caching=True,
@@ -2328,7 +2341,7 @@ def test_eagle_enabled_removes_last_block():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, num_blocks=10),
max_model_len=8192,
enable_caching=True,
@@ -2361,7 +2374,7 @@ def test_eagle_enabled_removes_last_block():
def test_eagle_with_partial_blocks():
"""Test Eagle behavior with requests containing partial blocks."""
block_size = 16
manager = KVCacheManager(
manager = make_kv_cache_manager(
make_kv_cache_config(block_size, num_blocks=10),
max_model_len=8192,
enable_caching=True,
@@ -2397,7 +2410,7 @@ def test_eagle_with_sliding_window():
dtype=torch.float32,
sliding_window=block_size,
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[],
@@ -2482,7 +2495,7 @@ def test_different_block_size():
),
],
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -2565,7 +2578,7 @@ def test_hybrid_cache_blocks_swa_tail_window_only():
),
],
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -2633,7 +2646,7 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
),
],
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
@@ -2781,7 +2794,7 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt():
],
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
config,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
@@ -2837,7 +2850,7 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
],
)
manager = KVCacheManager(
manager = make_kv_cache_manager(
config,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
@@ -28,6 +28,7 @@ def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=T
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
scheduler_block_size=sliding_window_spec.block_size,
max_admission_blocks_per_request=10**9,
)
@@ -40,6 +41,7 @@ def get_chunked_local_attention_manager(
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
scheduler_block_size=chunked_local_attention_spec.block_size,
max_admission_blocks_per_request=10**9,
)
@@ -458,6 +460,7 @@ def test_predictor_matches_allocator_blocks_calculation_with_admission_cap():
block_pool=block_pool,
enable_caching=False,
kv_cache_group_id=0,
scheduler_block_size=spec.block_size,
max_admission_blocks_per_request=cap,
)
@@ -28,6 +28,7 @@ from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import (
get_request_block_hasher,
init_none_hash,
resolve_kv_cache_block_sizes,
)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
@@ -230,13 +231,18 @@ class RequestRunner:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
self.num_kv_groups = len(kv_cache_config.kv_cache_groups)
scheduler_block_size, hash_block_size = resolve_kv_cache_block_sizes(
kv_cache_config, vllm_config
)
scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
self.scheduler = scheduler_cls(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
block_size=scheduler_block_size,
hash_block_size=hash_block_size,
)
self.worker_connector = OffloadingConnector(
+26 -16
View File
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
@@ -40,12 +39,20 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
scheduler_block_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
# The scheduling granularity (LCM of all group block sizes), must be a multiple
# of the hash_block_size and the block size of each group.
assert scheduler_block_size % hash_block_size == 0 and all(
scheduler_block_size % g.kv_cache_spec.block_size == 0
for g in kv_cache_config.kv_cache_groups
)
self.scheduler_block_size = scheduler_block_size
self.block_pool = BlockPool(
num_gpu_blocks=kv_cache_config.num_blocks,
@@ -73,6 +80,7 @@ class KVCacheCoordinator(ABC):
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=self.scheduler_block_size,
)
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
)
@@ -290,6 +298,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
scheduler_block_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
@@ -302,6 +311,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
@@ -338,6 +348,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
scheduler_block_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
@@ -350,6 +361,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
@@ -405,6 +417,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
scheduler_block_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
@@ -417,6 +430,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
@@ -468,14 +482,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
key=lambda x: not isinstance(x[0], FullAttentionSpec),
)
# The LCM of the block sizes of all attention types.
# The cache hit length must be a multiple of the LCM of the block sizes
# to make sure the cache hit length is a multiple of the block size of
# each attention type. Requiring this because we don't support partial
# block cache hit yet.
block_sizes = [spec.block_size for spec, _, _ in attention_groups]
self.lcm_block_size = lcm(*block_sizes)
# Attention-group indices (into ``self.attention_groups``) that
# contain at least one EAGLE/MTP KV cache group.
self.eagle_attn_group_indices: set[int] = {
@@ -486,18 +492,18 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
# Cache hits in this coordinator are always a multiple of
# ``lcm_block_size`` tokens (see ``find_longest_cache_hit``). Within an
# aligned region, SWA groups only consult a subset of blocks per
# ``lcm_block_size``-segment so the unused blocks also stay out of the
# prefix-cache hash map.
# ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``).
# Within an aligned region, SWA groups only consult a subset of blocks
# per ``scheduler_block_size``-segment so the unused blocks also stay
# out of the prefix-cache hash map.
num_computed_tokens = (
num_computed_tokens // self.lcm_block_size * self.lcm_block_size
num_computed_tokens // self.scheduler_block_size * self.scheduler_block_size
)
for manager in self.single_type_managers:
manager.cache_blocks(
request,
num_computed_tokens,
alignment_tokens=self.lcm_block_size,
alignment_tokens=self.scheduler_block_size,
)
def find_longest_cache_hit(
@@ -576,7 +582,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=spec,
use_eagle=use_eagle,
alignment_tokens=self.lcm_block_size,
alignment_tokens=self.scheduler_block_size,
)
_new_hit_length = len(hit_blocks[0]) * spec.block_size
if use_eagle:
@@ -616,6 +622,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
scheduler_block_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
) -> KVCacheCoordinator:
@@ -628,6 +635,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
@@ -641,6 +649,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
@@ -653,6 +662,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
+2
View File
@@ -112,6 +112,7 @@ class KVCacheManager:
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
scheduler_block_size: int,
hash_block_size: int,
max_num_batched_tokens: int | None = None,
enable_caching: bool = True,
@@ -147,6 +148,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=scheduler_block_size,
hash_block_size=hash_block_size,
metrics_collector=self.metrics_collector,
)
+1
View File
@@ -237,6 +237,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
scheduler_block_size=self.block_size,
hash_block_size=hash_block_size,
metrics_collector=self.kv_metrics_collector,
)
@@ -40,6 +40,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool,
enable_caching: bool,
kv_cache_group_id: int,
scheduler_block_size: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
max_admission_blocks_per_request: int | None = None,
@@ -50,6 +51,8 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
scheduler_block_size: The scheduling granularity (LCM of all group
block sizes); a multiple of this manager's ``block_size``.
max_admission_blocks_per_request: Recycling-aware per-request
block cap used by `get_num_blocks_to_allocate`. Only set for
spec types that recycle blocks across chunks (SWA,
@@ -57,6 +60,7 @@ class SingleTypeKVCacheManager(ABC):
correct for full-attention-style specs that hold every
block until the request finishes.
"""
self.scheduler_block_size = scheduler_block_size
self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
self.pcp_world_size = pcp_world_size
+1
View File
@@ -127,6 +127,7 @@ class SimpleCPUOffloadScheduler:
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
scheduler_block_size=self.block_size,
hash_block_size=self.hash_block_size,
)
self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool