diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 7ce87540cd1..91c5f37b417 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2466,6 +2466,201 @@ def test_eagle_with_sliding_window(): assert num_tokens == 0 +def test_eagle_swa_alignment_caches_extra_block(): + """Regression: SWA + EAGLE with `sliding_window <= alignment_tokens`. + + When the cache-hit alignment (lcm of per-group block sizes) is larger than + the SWA window, the SWA mask only kept the last block of each aligned + segment. EAGLE/MTP lookup needs ``tail + 1`` contiguous cached blocks and + that +1 block lives at the next segment's first position, which was left + uncached. The fix caches that extra block when ``use_eagle=True``. + """ + block_size = 8 + # Full group uses 4 * block_size, so lcm/alignment is 4 * block_size. + # SWA group has sliding_window = block_size (i.e., tail = 1 block). + # Without the fix, the second cached block needed for the EAGLE 2-block + # match never exists -> EAGLE cache hit fails entirely. + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["full"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["swa_mtp"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ), + is_eagle_group=True, + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + use_eagle=True, + ) + + # Prime the cache with a long prompt (16 swa blocks = 4 aligned segments). + token_ids = [i for i in range(16) for _ in range(block_size)] + req0 = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req0) + blocks = manager.allocate_slots( + req0, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + manager.free(req0) + + # Second request with identical prompt should find an EAGLE cache hit. + # Without the fix, ``num_computed_tokens`` is 0; with the fix, it lands at + # an alignment boundary (multiple of 32 tokens, minus the EAGLE drop). + req1 = make_request("1", token_ids, block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req1) + assert num_computed_tokens > 0, ( + "EAGLE + SWA with sliding_window <= alignment failed to find any " + "cache hit; the +1 block past each segment boundary must be cached." + ) + # Each aligned segment contributes 4 * block_size = 32 tokens; EAGLE drops + # the last block (block_size tokens) from the hit. + assert num_computed_tokens % (4 * block_size) == 0 + + +def test_eagle_swa_boundary_caches_post_boundary_block(): + """EAGLE + SWA must cache the first block after an alignment boundary. + + A 40-token computed prefix with 8-token SWA blocks and 32-token hybrid + alignment needs SWA blocks 3 and 4 cached to reuse a 32-token prefix: + block 3 is the segment tail, and block 4 is the EAGLE lookahead block + that gets dropped after lookup. + """ + block_size = 8 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["full"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec( + ["swa_mtp"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ), + is_eagle_group=True, + ), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + use_eagle=True, + ) + + token_ids = [i for i in range(5) for _ in range(block_size)] + req0 = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req0) + blocks = manager.allocate_slots( + req0, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + assert pool.get_cached_block(req0.block_hashes[3], kv_cache_group_ids=[1]) + assert pool.get_cached_block(req0.block_hashes[4], kv_cache_group_ids=[1]) + manager.free(req0) + + req1 = make_request("1", token_ids + [999], block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req1) + assert num_computed_tokens == 4 * block_size + + +def test_eagle_grouped_swa_siblings_use_same_cache_mask(): + """Grouped SWA siblings must cache the EAGLE lookahead block together.""" + block_size = 8 + swa_spec = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=block_size, + ) + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["full"], + FullAttentionSpec( + block_size=4 * block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), + ), + KVCacheGroupSpec(["swa_main"], swa_spec), + KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True), + ], + ) + manager = make_kv_cache_manager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + use_eagle=True, + ) + + token_ids = [i for i in range(9) for _ in range(block_size)] + req0 = make_request("0", token_ids, block_size, sha256) + computed_blocks, _ = manager.get_computed_blocks(req0) + blocks = manager.allocate_slots( + req0, + len(token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + pool = manager.block_pool + assert pool.get_cached_block(req0.block_hashes[4], kv_cache_group_ids=[1, 2]) + assert pool.get_cached_block(req0.block_hashes[8], kv_cache_group_ids=[1, 2]) + manager.free(req0) + + req1 = make_request("1", token_ids + [999], block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req1) + assert num_computed_tokens == 8 * block_size + + def test_different_block_size(): block_size = 16 # full attention and sliding window attention layers have the same page size: @@ -2614,7 +2809,7 @@ def test_hybrid_cache_blocks_swa_tail_window_only(): def test_hybrid_cache_blocks_clamped_to_lcm(): - """HybridKVCacheCoordinator.cache_blocks() clamps to lcm_block_size. + """HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size. Chunks past the last lcm-aligned boundary can never participate in a cache hit (find_longest_cache_hit always returns lcm-aligned hits), so caching them only pollutes the prefix-cache hash map and keeps blocks diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index e82c17ca60c..0e3e8879359 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -86,7 +86,7 @@ def test_chunked_local_attention_possible_cached_prefix(): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=chunked_local_attention_spec, - use_eagle=False, + drop_eagle_block=False, alignment_tokens=block_size, )[0] assert len(computed_blocks) == expect_length @@ -157,7 +157,7 @@ def test_sliding_window_possible_cached_prefix(): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=sliding_window_spec, - use_eagle=False, + drop_eagle_block=False, alignment_tokens=block_size, )[0] assert len(computed_blocks) == expect_length diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py index b16fdb7c16c..a17f0b5f5ff 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py @@ -233,7 +233,7 @@ class MooncakeStoreCoordinator: kv_cache_group_ids=group_ids, block_pool=cast(BlockPool, cached_block_pool), kv_cache_spec=spec, - use_eagle=(0 in eagle_indices), + drop_eagle_block=(0 in eagle_indices), alignment_tokens=spec.block_size, ) num_groups = len(self.kv_cache_groups) @@ -262,9 +262,9 @@ class MooncakeStoreCoordinator: ) continue - use_eagle = idx in eagle_indices and idx not in eagle_verified + drop_eagle_block = idx in eagle_indices and idx not in eagle_verified _max_length = curr_hit_length - if use_eagle: + if drop_eagle_block: _max_length = min(curr_hit_length + spec.block_size, max_length) hashes = self.block_hashes_for_spec(block_hashes, spec) hit_blocks = manager_cls.find_longest_cache_hit( @@ -273,11 +273,11 @@ class MooncakeStoreCoordinator: kv_cache_group_ids=group_ids, block_pool=cast(BlockPool, cached_block_pool), kv_cache_spec=spec, - use_eagle=use_eagle, + drop_eagle_block=drop_eagle_block, alignment_tokens=self.lcm_block_size, ) _new_hit_length = len(hit_blocks[0]) * spec.block_size - if use_eagle: + if drop_eagle_block: eagle_verified.add(idx) elif _new_hit_length < curr_hit_length: eagle_verified.clear() diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 0336ecdfd28..387f1a1e335 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import NamedTuple from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector @@ -381,6 +382,8 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" ) + # Single group; useless but just set ``use_eagle`` for consistency regardless. + self.single_type_managers[0].use_eagle = 0 in self.eagle_group_ids def find_longest_cache_hit( self, @@ -393,7 +396,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): kv_cache_group_ids=[0], block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, - use_eagle=0 in self.eagle_group_ids, + drop_eagle_block=0 in self.eagle_group_ids, alignment_tokens=self.block_size, dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, @@ -401,6 +404,21 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): return hit_blocks, len(hit_blocks[0]) * self.block_size +class SpecGroup(NamedTuple): + """KV cache groups that share one spec, batched together for a single + cache-hit lookup. + + ``use_eagle`` is True iff any member group is an EAGLE/MTP group. Members + sharing a spec are cached and looked up jointly, so the EAGLE last-block drop + is necessarily decided for the whole spec group. + """ + + spec: KVCacheSpec + group_ids: list[int] + manager_cls: type[SingleTypeKVCacheManager] + use_eagle: bool + + class HybridKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for hybrid models with multiple KV cache types, and @@ -452,58 +470,63 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): Groups KV cache groups by their spec type for efficient batch processing during cache hit lookup. """ - attention_groups: list[ - tuple[KVCacheSpec, list[int], type[SingleTypeKVCacheManager]] - ] = [] - + self.attention_groups: list[SpecGroup] = [] for i, g in enumerate(self.kv_cache_config.kv_cache_groups): manager_cls = self.single_type_managers[i].__class__ spec = g.kv_cache_spec + use_eagle = i in self.eagle_group_ids # Try to find an existing group with the same spec - for existing_spec, group_ids, existing_cls in attention_groups: - if existing_spec == spec: - assert manager_cls is existing_cls, ( + for idx, group in enumerate(self.attention_groups): + if group.spec == spec: + assert manager_cls is group.manager_cls, ( "Expected same manager class for identical KV cache specs." ) - group_ids.append(i) + group.group_ids.append(i) + if use_eagle and not group.use_eagle: + self.attention_groups[idx] = group._replace(use_eagle=True) break else: - attention_groups.append((spec, [i], manager_cls)) + self.attention_groups.append( + SpecGroup(spec, [i], manager_cls, use_eagle) + ) - assert len(attention_groups) > 1, ( + assert len(self.attention_groups) > 1, ( "HybridKVCacheCoordinator requires at least two attention groups." ) # Put full attention first: its efficient left-to-right scan provides # a tighter initial bound, reducing work for subsequent groups. - self.attention_groups = sorted( - attention_groups, - key=lambda x: not isinstance(x[0], FullAttentionSpec), + self.attention_groups.sort( + key=lambda g: not isinstance(g.spec, FullAttentionSpec) ) - # Attention-group indices (into ``self.attention_groups``) that - # contain at least one EAGLE/MTP KV cache group. - self.eagle_attn_group_indices: set[int] = { - i - for i, (_, group_ids, _) in enumerate(self.attention_groups) - if any(gid in self.eagle_group_ids for gid in group_ids) - } + # Propagate the eagle bit to each manager (default to ``use_eagle=False``). + for group in self.attention_groups: + if group.use_eagle: + for gid in group.group_ids: + self.single_type_managers[gid].use_eagle = True def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: # Cache hits in this coordinator are always a multiple of # ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``). - # Within an aligned region, SWA groups only consult a subset of blocks + # Within an aligned region, SWA groups may only consult a subset of blocks # per ``scheduler_block_size``-segment so the unused blocks also stay # out of the prefix-cache hash map. - num_computed_tokens = ( + aligned_num_computed_tokens = ( num_computed_tokens // self.scheduler_block_size * self.scheduler_block_size ) for manager in self.single_type_managers: + num_tokens_to_cache = aligned_num_computed_tokens + # EAGLE groups match one block past each aligned boundary and drop + # it, so make that lookahead block eligible to be cached. + if manager.use_eagle and aligned_num_computed_tokens > 0: + num_tokens_to_cache = min( + num_computed_tokens, + aligned_num_computed_tokens + manager.block_size, + ) manager.cache_blocks( - request, - num_computed_tokens, - alignment_tokens=self.scheduler_block_size, + request, num_tokens_to_cache, alignment_tokens=self.scheduler_block_size ) def find_longest_cache_hit( @@ -543,7 +566,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): # Simple hybrid (1 full attn + 1 other): one iteration suffices. # Full attn is always first if it exists. is_simple_hybrid = len(self.attention_groups) == 2 and isinstance( - self.attention_groups[0][0], FullAttentionSpec + self.attention_groups[0].spec, FullAttentionSpec ) # Attention-group indices whose EAGLE drop is verified at the current @@ -554,7 +577,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): while True: curr_hit_length = hit_length - for idx, (spec, group_ids, manager_cls) in enumerate(self.attention_groups): + for idx, (spec, group_ids, manager_cls, use_eagle) in enumerate( + self.attention_groups + ): cached_blocks = hit_blocks_by_group[group_ids[0]] if isinstance(spec, FullAttentionSpec) and cached_blocks is not None: # Full attention is downward-closed: we only need to look @@ -565,12 +590,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ) continue - use_eagle = ( - idx in self.eagle_attn_group_indices and idx not in eagle_verified - ) + drop_eagle_block = use_eagle and idx not in eagle_verified _max_length = curr_hit_length - if use_eagle: + if drop_eagle_block: # Eagle needs to match one more block and then pop the last. _max_length = min( curr_hit_length + spec.block_size, max_cache_hit_length @@ -581,11 +604,11 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): kv_cache_group_ids=group_ids, block_pool=self.block_pool, kv_cache_spec=spec, - use_eagle=use_eagle, + drop_eagle_block=drop_eagle_block, alignment_tokens=self.scheduler_block_size, ) _new_hit_length = len(hit_blocks[0]) * spec.block_size - if use_eagle: + if drop_eagle_block: eagle_verified.add(idx) elif _new_hit_length < curr_hit_length: # length shrunk; invalidate previous eagle verifications @@ -601,10 +624,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): break # Truncate full attention blocks to final hit_length (if present) - spec, group_ids, _ = self.attention_groups[0] - if isinstance(spec, FullAttentionSpec): - num_blocks = hit_length // spec.block_size - for group_id in group_ids: + first_group = self.attention_groups[0] + if isinstance(first_group.spec, FullAttentionSpec): + num_blocks = hit_length // first_group.spec.block_size + for group_id in first_group.group_ids: if (blks := hit_blocks_by_group[group_id]) is not None: del blks[num_blocks:] diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7559e6c1c7b..7b16d9c6f05 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -86,6 +86,12 @@ class SingleTypeKVCacheManager(ABC): self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + # Whether this group's prefix-cache hits drop the EAGLE/MTP lookahead + # block. Only consulted by managers whose hit logic is sparse within an + # aligned segment (SWA). Initialized lazily by the coordinator after + # determining the attention groups. + self.use_eagle = False + @classmethod def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks) @@ -317,8 +323,12 @@ class SingleTypeKVCacheManager(ABC): if alignment_tokens is None or alignment_tokens <= self.block_size: block_mask = None else: - block_mask = self._cache_block_mask( - num_cached_blocks, num_full_blocks, alignment_tokens + block_mask = self.reachable_block_mask( + num_cached_blocks, + num_full_blocks, + alignment_tokens, + self.kv_cache_spec, + self.use_eagle, ) self.block_pool.cache_full_blocks( request=request, @@ -332,11 +342,14 @@ class SingleTypeKVCacheManager(ABC): self.num_cached_block[request.request_id] = num_full_blocks - def _cache_block_mask( - self, - num_cached_blocks: int, - num_full_blocks: int, + @classmethod + def reachable_block_mask( + cls, + start_block: int, + num_blocks: int, alignment_tokens: int, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, ) -> list[bool] | None: """Per-block mask for ``cache_full_blocks``. ``None`` means cache every (non-null) block — the default for full attention. @@ -389,7 +402,7 @@ class SingleTypeKVCacheManager(ABC): kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, - use_eagle: bool, + drop_eagle_block: bool, alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -409,7 +422,10 @@ class SingleTypeKVCacheManager(ABC): kv_cache_group_ids: The ids of the kv cache groups. block_pool: The block pool. kv_cache_spec: The kv cache spec. - use_eagle: Whether to use eagle. + drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP. + Always False for non-EAGLE/MTP groups, but can be False for EAGLE/MTP + groups too if the last block is already dropped (e.g., in a + convergence loop in `find_longest_cache_hit`). alignment_tokens: The returned cache hit length (in tokens) should be a multiple of this value (in tokens). By default, it should be set to the block_size. @@ -499,7 +515,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, - use_eagle: bool, + drop_eagle_block: bool, alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -528,7 +544,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): computed.append(cached) else: break - if use_eagle and computed_blocks[0]: + if drop_eagle_block and computed_blocks[0]: # Need to drop the last matched block if eagle is enabled. for computed in computed_blocks: computed.pop() @@ -556,6 +572,19 @@ class SlidingWindowManager(SingleTypeKVCacheManager): super().__init__(kv_cache_spec, **kwargs) self.sliding_window = kv_cache_spec.sliding_window + @classmethod + def _contiguous_blocks_for_hit( + cls, window_size: int, block_size: int, use_eagle: bool + ) -> int: + blocks = cdiv(window_size - 1, block_size) + if use_eagle: + # Need to drop the last matched block if eagle is enabled. For + # sliding window layer, we achieve this by increasing the number of + # contiguous blocks needed for prefix cache hit by one and dropping + # the last matched block. + blocks += 1 + return blocks + @classmethod def find_longest_cache_hit( cls, @@ -564,7 +593,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, - use_eagle: bool, + drop_eagle_block: bool, alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -575,17 +604,10 @@ class SlidingWindowManager(SingleTypeKVCacheManager): assert dcp_world_size == 1, "DCP not support sliding window attn now." assert pcp_world_size == 1, "PCP not support sliding window attn now." - # The number of contiguous blocks needed for prefix cache hit. - # -1 since the input token itself is also included in the window - sliding_window_contiguous_blocks = cdiv( - kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size + # The number of contiguous blocks needed for a prefix cache hit. + sliding_window_contiguous_blocks = cls._contiguous_blocks_for_hit( + kv_cache_spec.sliding_window, kv_cache_spec.block_size, drop_eagle_block ) - if use_eagle: - # Need to drop the last matched block if eagle is enabled. For - # sliding window layer, we achieve this by increasing the number of - # contiguous blocks needed for prefix cache hit by one and dropping - # the last matched block. - sliding_window_contiguous_blocks += 1 # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(max_num_blocks) to @@ -608,7 +630,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # Skip prefix matching check if the block is not aligned with # `alignment_tokens`. if num_contiguous_blocks == 0 and block_size != alignment_tokens: - post_pop_blocks = i if use_eagle else i + 1 + post_pop_blocks = i if drop_eagle_block else i + 1 if (post_pop_blocks * block_size) % alignment_tokens != 0: continue # Add the cached block to the computed blocks. @@ -636,7 +658,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): ): for computed in computed_blocks: computed.pop() - if use_eagle and computed_blocks[0]: + if drop_eagle_block and computed_blocks[0]: for computed in computed_blocks: computed.pop() # Re-align after eagle pop: the pop may break the alignment @@ -650,17 +672,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager): computed.pop() return computed_blocks - def _cache_block_mask( - self, num_cached_blocks: int, num_full_blocks: int, alignment_tokens: int + @classmethod + def reachable_block_mask( + cls, + start_block: int, + num_blocks: int, + alignment_tokens: int, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, ) -> list[bool] | None: - assert alignment_tokens > self.block_size - per_segment = alignment_tokens // self.block_size - tail = cdiv(self.sliding_window - 1, self.block_size) - if tail >= per_segment: + assert alignment_tokens > kv_cache_spec.block_size + assert isinstance(kv_cache_spec, SlidingWindowSpec) + per_segment = alignment_tokens // kv_cache_spec.block_size + need = cls._contiguous_blocks_for_hit( + window_size=kv_cache_spec.sliding_window, + block_size=kv_cache_spec.block_size, + use_eagle=use_eagle, + ) + if need >= per_segment: return None - skip = per_segment - tail + # The matched run's right edge sits on the aligned boundary block when + # EAGLE peeks one block past it (shift=1), otherwise on the last block + # before the boundary (shift=0). A block is reachable iff it falls in + # the ``need``-wide run ending at some boundary's right edge. + shift = 1 if use_eagle else 0 return [ - i % per_segment >= skip for i in range(num_cached_blocks, num_full_blocks) + i >= shift and (i - shift) % per_segment >= per_segment - need + for i in range(start_block, num_blocks) ] def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: @@ -714,7 +752,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, - use_eagle: bool, + drop_eagle_block: bool, alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -745,7 +783,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): kv_cache_group_ids: The ids of the kv cache groups. block_pool: The block pool. kv_cache_spec: The kv cache spec. - use_eagle: Whether to use eagle. + drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP. dcp_world_size: The world size of decode context parallelism. pcp_world_size: The world size of prefill context parallelism. alignment_tokens: The returned cache hit length (in tokens) should @@ -758,7 +796,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): "ChunkedLocalAttentionManager can only be used for " "chunked local attention groups" ) - assert use_eagle is False, ( + assert drop_eagle_block is False, ( "Hybrid KV cache is not supported for " + "eagle + chunked local attention." ) assert dcp_world_size == 1, "DCP not support chunked local attn now." @@ -874,7 +912,7 @@ class MambaManager(SingleTypeKVCacheManager): kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, - use_eagle: bool, + drop_eagle_block: bool, alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -1168,7 +1206,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, - use_eagle: bool, + drop_eagle_block: bool, alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1,