[Bugfix] Cache the EAGLE/MTP lookahead block in the SWA prefix-cache mask (#44082)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
This commit is contained in:
Yifan Qiao
2026-06-02 12:21:07 -07:00
committed by GitHub
parent e4a2e584e5
commit e9e08c49b9
5 changed files with 338 additions and 82 deletions
+196 -1
View File
@@ -2466,6 +2466,201 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0
def test_eagle_swa_alignment_caches_extra_block():
    """Regression: SWA + EAGLE with `sliding_window <= alignment_tokens`.

    When the cache-hit alignment (lcm of per-group block sizes) is larger than
    the SWA window, the SWA mask only kept the last block of each aligned
    segment. EAGLE/MTP lookup needs ``tail + 1`` contiguous cached blocks and
    that +1 block lives at the next segment's first position, which was left
    uncached. The fix caches that extra block when ``use_eagle=True``.
    """
    block_size = 8
    # Full group uses 4 * block_size, so lcm/alignment is 4 * block_size.
    # SWA group has sliding_window = block_size (i.e., tail = 1 block).
    # Without the fix, the second cached block needed for the EAGLE 2-block
    # match never exists -> EAGLE cache hit fails entirely.
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["full"],
                FullAttentionSpec(
                    block_size=4 * block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(
                ["swa_mtp"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=block_size,
                ),
                is_eagle_group=True,
            ),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # Prime the cache with a long prompt (16 swa blocks = 4 aligned segments).
    token_ids = [i for i in range(16) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    # Take the computed-token count from the lookup itself: group 0 is the
    # full-attention group with 4 * block_size tokens per block, so scaling its
    # block count by ``block_size`` would be in the wrong unit.
    computed_blocks, num_cached_tokens = manager.get_computed_blocks(req0)
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None
    manager.free(req0)

    # Second request with identical prompt should find an EAGLE cache hit.
    # Without the fix, ``num_computed_tokens`` is 0; with the fix, the hit that
    # remains after the EAGLE lookahead block is dropped lands on an alignment
    # boundary (a multiple of 32 tokens).
    req1 = make_request("1", token_ids, block_size, sha256)
    _, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens > 0, (
        "EAGLE + SWA with sliding_window <= alignment failed to find any "
        "cache hit; the +1 block past each segment boundary must be cached."
    )
    # Each aligned segment contributes 4 * block_size = 32 tokens. The EAGLE
    # lookahead block is matched past the boundary and then dropped, so the
    # reported hit must stay a whole number of segments.
    assert num_computed_tokens % (4 * block_size) == 0
def test_eagle_swa_boundary_caches_post_boundary_block():
    """EAGLE + SWA must cache the first block after an alignment boundary.

    A 40-token computed prefix with 8-token SWA blocks and 32-token hybrid
    alignment needs SWA blocks 3 and 4 cached to reuse a 32-token prefix:
    block 3 is the segment tail, and block 4 is the EAGLE lookahead block
    that gets dropped after lookup.
    """
    block_size = 8
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["full"],
                FullAttentionSpec(
                    block_size=4 * block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(
                ["swa_mtp"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=block_size,
                ),
                is_eagle_group=True,
            ),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 5 SWA blocks = 40 tokens: one full 32-token segment plus one extra block.
    token_ids = [i for i in range(5) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    # Take the computed-token count from the lookup itself: group 0 is the
    # full-attention group with 4 * block_size tokens per block, so scaling its
    # block count by ``block_size`` would be in the wrong unit.
    computed_blocks, num_cached_tokens = manager.get_computed_blocks(req0)
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # Group 1 is the SWA/EAGLE group: both the segment tail (block 3) and the
    # post-boundary lookahead block (block 4) must be in the prefix cache.
    pool = manager.block_pool
    assert pool.get_cached_block(req0.block_hashes[3], kv_cache_group_ids=[1])
    assert pool.get_cached_block(req0.block_hashes[4], kv_cache_group_ids=[1])
    manager.free(req0)

    # Same prompt plus one extra token: the whole first segment is reusable.
    req1 = make_request("1", token_ids + [999], block_size, sha256)
    _, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 4 * block_size
def test_eagle_grouped_swa_siblings_use_same_cache_mask():
    """Grouped SWA siblings must cache the EAGLE lookahead block together."""
    block_size = 8
    # Both SWA groups share this exact spec; only ``swa_mtp`` is flagged as an
    # EAGLE group.
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["full"],
                FullAttentionSpec(
                    block_size=4 * block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(["swa_main"], swa_spec),
            KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 9 SWA blocks = 72 tokens: two full 32-token segments plus one extra block.
    token_ids = [i for i in range(9) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    # Take the computed-token count from the lookup itself: group 0 is the
    # full-attention group with 4 * block_size tokens per block, so scaling its
    # block count by ``block_size`` would be in the wrong unit.
    computed_blocks, num_cached_tokens = manager.get_computed_blocks(req0)
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # Blocks 4 and 8 are the first blocks after the 32- and 64-token
    # boundaries; they must be cached for both SWA groups (ids 1 and 2).
    pool = manager.block_pool
    assert pool.get_cached_block(req0.block_hashes[4], kv_cache_group_ids=[1, 2])
    assert pool.get_cached_block(req0.block_hashes[8], kv_cache_group_ids=[1, 2])
    manager.free(req0)

    # Same prompt plus one extra token: both full segments are reusable.
    req1 = make_request("1", token_ids + [999], block_size, sha256)
    _, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 8 * block_size
def test_different_block_size():
block_size = 16
# full attention and sliding window attention layers have the same page size:
@@ -2614,7 +2809,7 @@ def test_hybrid_cache_blocks_swa_tail_window_only():
def test_hybrid_cache_blocks_clamped_to_lcm():
"""HybridKVCacheCoordinator.cache_blocks() clamps to lcm_block_size.
"""HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size.
Chunks past the last lcm-aligned boundary can never participate in a
cache hit (find_longest_cache_hit always returns lcm-aligned hits), so
caching them only pollutes the prefix-cache hash map and keeps blocks
@@ -86,7 +86,7 @@ def test_chunked_local_attention_possible_cached_prefix():
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False,
drop_eagle_block=False,
alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@@ -157,7 +157,7 @@ def test_sliding_window_possible_cached_prefix():
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False,
drop_eagle_block=False,
alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@@ -233,7 +233,7 @@ class MooncakeStoreCoordinator:
kv_cache_group_ids=group_ids,
block_pool=cast(BlockPool, cached_block_pool),
kv_cache_spec=spec,
use_eagle=(0 in eagle_indices),
drop_eagle_block=(0 in eagle_indices),
alignment_tokens=spec.block_size,
)
num_groups = len(self.kv_cache_groups)
@@ -262,9 +262,9 @@ class MooncakeStoreCoordinator:
)
continue
use_eagle = idx in eagle_indices and idx not in eagle_verified
drop_eagle_block = idx in eagle_indices and idx not in eagle_verified
_max_length = curr_hit_length
if use_eagle:
if drop_eagle_block:
_max_length = min(curr_hit_length + spec.block_size, max_length)
hashes = self.block_hashes_for_spec(block_hashes, spec)
hit_blocks = manager_cls.find_longest_cache_hit(
@@ -273,11 +273,11 @@ class MooncakeStoreCoordinator:
kv_cache_group_ids=group_ids,
block_pool=cast(BlockPool, cached_block_pool),
kv_cache_spec=spec,
use_eagle=use_eagle,
drop_eagle_block=drop_eagle_block,
alignment_tokens=self.lcm_block_size,
)
_new_hit_length = len(hit_blocks[0]) * spec.block_size
if use_eagle:
if drop_eagle_block:
eagle_verified.add(idx)
elif _new_hit_length < curr_hit_length:
eagle_verified.clear()
+61 -38
View File
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import NamedTuple
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
@@ -381,6 +382,8 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
# Single group, so the flag has no practical effect here; set ``use_eagle``
# anyway for consistency.
self.single_type_managers[0].use_eagle = 0 in self.eagle_group_ids
def find_longest_cache_hit(
self,
@@ -393,7 +396,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=0 in self.eagle_group_ids,
drop_eagle_block=0 in self.eagle_group_ids,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
@@ -401,6 +404,21 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
return hit_blocks, len(hit_blocks[0]) * self.block_size
class SpecGroup(NamedTuple):
    """A batch of KV cache groups with an identical spec.

    Groups whose specs compare equal are cached and looked up together in a
    single cache-hit lookup, so the decision to drop the EAGLE/MTP last block
    can only be made for the batch as a whole. ``use_eagle`` is therefore True
    as soon as at least one member group is an EAGLE/MTP group.
    """

    # The KV cache spec shared by every member group.
    spec: KVCacheSpec
    # Ids of the member KV cache groups; appended to as members are discovered.
    group_ids: list[int]
    # The single-type manager class shared by the member groups.
    manager_cls: type[SingleTypeKVCacheManager]
    # True iff any member group is an EAGLE/MTP group.
    use_eagle: bool
class HybridKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
@@ -452,58 +470,63 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
Groups KV cache groups by their spec type for efficient batch processing
during cache hit lookup.
"""
attention_groups: list[
tuple[KVCacheSpec, list[int], type[SingleTypeKVCacheManager]]
] = []
self.attention_groups: list[SpecGroup] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
manager_cls = self.single_type_managers[i].__class__
spec = g.kv_cache_spec
use_eagle = i in self.eagle_group_ids
# Try to find an existing group with the same spec
for existing_spec, group_ids, existing_cls in attention_groups:
if existing_spec == spec:
assert manager_cls is existing_cls, (
for idx, group in enumerate(self.attention_groups):
if group.spec == spec:
assert manager_cls is group.manager_cls, (
"Expected same manager class for identical KV cache specs."
)
group_ids.append(i)
group.group_ids.append(i)
if use_eagle and not group.use_eagle:
self.attention_groups[idx] = group._replace(use_eagle=True)
break
else:
attention_groups.append((spec, [i], manager_cls))
self.attention_groups.append(
SpecGroup(spec, [i], manager_cls, use_eagle)
)
assert len(attention_groups) > 1, (
assert len(self.attention_groups) > 1, (
"HybridKVCacheCoordinator requires at least two attention groups."
)
# Put full attention first: its efficient left-to-right scan provides
# a tighter initial bound, reducing work for subsequent groups.
self.attention_groups = sorted(
attention_groups,
key=lambda x: not isinstance(x[0], FullAttentionSpec),
self.attention_groups.sort(
key=lambda g: not isinstance(g.spec, FullAttentionSpec)
)
# Attention-group indices (into ``self.attention_groups``) that
# contain at least one EAGLE/MTP KV cache group.
self.eagle_attn_group_indices: set[int] = {
i
for i, (_, group_ids, _) in enumerate(self.attention_groups)
if any(gid in self.eagle_group_ids for gid in group_ids)
}
# Propagate the eagle bit to each manager (default to ``use_eagle=False``).
for group in self.attention_groups:
if group.use_eagle:
for gid in group.group_ids:
self.single_type_managers[gid].use_eagle = True
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
# Cache hits in this coordinator are always a multiple of
# ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``).
# Within an aligned region, SWA groups only consult a subset of blocks
# Within an aligned region, SWA groups may only consult a subset of blocks
# per ``scheduler_block_size``-segment so the unused blocks also stay
# out of the prefix-cache hash map.
num_computed_tokens = (
aligned_num_computed_tokens = (
num_computed_tokens // self.scheduler_block_size * self.scheduler_block_size
)
for manager in self.single_type_managers:
num_tokens_to_cache = aligned_num_computed_tokens
# EAGLE groups match one block past each aligned boundary and drop
# it, so make that lookahead block eligible to be cached.
if manager.use_eagle and aligned_num_computed_tokens > 0:
num_tokens_to_cache = min(
num_computed_tokens,
aligned_num_computed_tokens + manager.block_size,
)
manager.cache_blocks(
request,
num_computed_tokens,
alignment_tokens=self.scheduler_block_size,
request, num_tokens_to_cache, alignment_tokens=self.scheduler_block_size
)
def find_longest_cache_hit(
@@ -543,7 +566,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
# Full attn is always first if it exists.
is_simple_hybrid = len(self.attention_groups) == 2 and isinstance(
self.attention_groups[0][0], FullAttentionSpec
self.attention_groups[0].spec, FullAttentionSpec
)
# Attention-group indices whose EAGLE drop is verified at the current
@@ -554,7 +577,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
while True:
curr_hit_length = hit_length
for idx, (spec, group_ids, manager_cls) in enumerate(self.attention_groups):
for idx, (spec, group_ids, manager_cls, use_eagle) in enumerate(
self.attention_groups
):
cached_blocks = hit_blocks_by_group[group_ids[0]]
if isinstance(spec, FullAttentionSpec) and cached_blocks is not None:
# Full attention is downward-closed: we only need to look
@@ -565,12 +590,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
)
continue
use_eagle = (
idx in self.eagle_attn_group_indices and idx not in eagle_verified
)
drop_eagle_block = use_eagle and idx not in eagle_verified
_max_length = curr_hit_length
if use_eagle:
if drop_eagle_block:
# Eagle needs to match one more block and then pop the last.
_max_length = min(
curr_hit_length + spec.block_size, max_cache_hit_length
@@ -581,11 +604,11 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
kv_cache_group_ids=group_ids,
block_pool=self.block_pool,
kv_cache_spec=spec,
use_eagle=use_eagle,
drop_eagle_block=drop_eagle_block,
alignment_tokens=self.scheduler_block_size,
)
_new_hit_length = len(hit_blocks[0]) * spec.block_size
if use_eagle:
if drop_eagle_block:
eagle_verified.add(idx)
elif _new_hit_length < curr_hit_length:
# length shrunk; invalidate previous eagle verifications
@@ -601,10 +624,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
break
# Truncate full attention blocks to final hit_length (if present)
spec, group_ids, _ = self.attention_groups[0]
if isinstance(spec, FullAttentionSpec):
num_blocks = hit_length // spec.block_size
for group_id in group_ids:
first_group = self.attention_groups[0]
if isinstance(first_group.spec, FullAttentionSpec):
num_blocks = hit_length // first_group.spec.block_size
for group_id in first_group.group_ids:
if (blks := hit_blocks_by_group[group_id]) is not None:
del blks[num_blocks:]
+74 -36
View File
@@ -86,6 +86,12 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
# Whether this group's prefix-cache hits drop the EAGLE/MTP lookahead
# block. Only consulted by managers whose hit logic is sparse within an
# aligned segment (SWA). Defaults to False; the coordinator overwrites it
# once it has determined the attention groups.
self.use_eagle = False
@classmethod
def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]):
return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks)
@@ -317,8 +323,12 @@ class SingleTypeKVCacheManager(ABC):
if alignment_tokens is None or alignment_tokens <= self.block_size:
block_mask = None
else:
block_mask = self._cache_block_mask(
num_cached_blocks, num_full_blocks, alignment_tokens
block_mask = self.reachable_block_mask(
num_cached_blocks,
num_full_blocks,
alignment_tokens,
self.kv_cache_spec,
self.use_eagle,
)
self.block_pool.cache_full_blocks(
request=request,
@@ -332,11 +342,14 @@ class SingleTypeKVCacheManager(ABC):
self.num_cached_block[request.request_id] = num_full_blocks
def _cache_block_mask(
self,
num_cached_blocks: int,
num_full_blocks: int,
@classmethod
def reachable_block_mask(
cls,
start_block: int,
num_blocks: int,
alignment_tokens: int,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[bool] | None:
"""Per-block mask for ``cache_full_blocks``. ``None`` means cache
every (non-null) block — the default for full attention.
@@ -389,7 +402,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
drop_eagle_block: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
@@ -409,7 +422,10 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP.
Always False for non-EAGLE/MTP groups, but can be False for EAGLE/MTP
groups too if the last block is already dropped (e.g., in a
convergence loop in `find_longest_cache_hit`).
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens). By default, it should
be set to the block_size.
@@ -499,7 +515,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
drop_eagle_block: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
@@ -528,7 +544,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
computed.append(cached)
else:
break
if use_eagle and computed_blocks[0]:
if drop_eagle_block and computed_blocks[0]:
# Need to drop the last matched block if eagle is enabled.
for computed in computed_blocks:
computed.pop()
@@ -556,6 +572,19 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
super().__init__(kv_cache_spec, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
@classmethod
def _contiguous_blocks_for_hit(
    cls, window_size: int, block_size: int, use_eagle: bool
) -> int:
    """Return the number of contiguous cached blocks a prefix-cache hit needs.

    Args:
        window_size: The sliding window size, in tokens.
        block_size: The number of tokens per block.
        use_eagle: Whether one extra block must be matched so that the last
            matched block can be dropped for EAGLE/MTP.

    Returns:
        The number of blocks covering ``window_size - 1`` tokens, plus one
        when ``use_eagle`` is set.
    """
    # -1 since the input token itself is also included in the window.
    window_blocks = cdiv(window_size - 1, block_size)
    # EAGLE needs the last matched block dropped. For a sliding window layer
    # this is achieved by requiring one more contiguous block for a hit and
    # then dropping the last matched block.
    lookahead_blocks = 1 if use_eagle else 0
    return window_blocks + lookahead_blocks
@classmethod
def find_longest_cache_hit(
cls,
@@ -564,7 +593,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
drop_eagle_block: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
@@ -575,17 +604,10 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
assert dcp_world_size == 1, "DCP not support sliding window attn now."
assert pcp_world_size == 1, "PCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
sliding_window_contiguous_blocks = cdiv(
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size
# The number of contiguous blocks needed for a prefix cache hit.
sliding_window_contiguous_blocks = cls._contiguous_blocks_for_hit(
kv_cache_spec.sliding_window, kv_cache_spec.block_size, drop_eagle_block
)
if use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
sliding_window_contiguous_blocks += 1
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(max_num_blocks) to
@@ -608,7 +630,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if num_contiguous_blocks == 0 and block_size != alignment_tokens:
post_pop_blocks = i if use_eagle else i + 1
post_pop_blocks = i if drop_eagle_block else i + 1
if (post_pop_blocks * block_size) % alignment_tokens != 0:
continue
# Add the cached block to the computed blocks.
@@ -636,7 +658,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]:
if drop_eagle_block and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
# Re-align after eagle pop: the pop may break the alignment
@@ -650,17 +672,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
computed.pop()
return computed_blocks
def _cache_block_mask(
self, num_cached_blocks: int, num_full_blocks: int, alignment_tokens: int
@classmethod
def reachable_block_mask(
cls,
start_block: int,
num_blocks: int,
alignment_tokens: int,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[bool] | None:
assert alignment_tokens > self.block_size
per_segment = alignment_tokens // self.block_size
tail = cdiv(self.sliding_window - 1, self.block_size)
if tail >= per_segment:
assert alignment_tokens > kv_cache_spec.block_size
assert isinstance(kv_cache_spec, SlidingWindowSpec)
per_segment = alignment_tokens // kv_cache_spec.block_size
need = cls._contiguous_blocks_for_hit(
window_size=kv_cache_spec.sliding_window,
block_size=kv_cache_spec.block_size,
use_eagle=use_eagle,
)
if need >= per_segment:
return None
skip = per_segment - tail
# The matched run's right edge sits on the aligned boundary block when
# EAGLE peeks one block past it (shift=1), otherwise on the last block
# before the boundary (shift=0). A block is reachable iff it falls in
# the ``need``-wide run ending at some boundary's right edge.
shift = 1 if use_eagle else 0
return [
i % per_segment >= skip for i in range(num_cached_blocks, num_full_blocks)
i >= shift and (i - shift) % per_segment >= per_segment - need
for i in range(start_block, num_blocks)
]
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
@@ -714,7 +752,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
drop_eagle_block: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
@@ -745,7 +783,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
alignment_tokens: The returned cache hit length (in tokens) should
@@ -758,7 +796,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"ChunkedLocalAttentionManager can only be used for "
"chunked local attention groups"
)
assert use_eagle is False, (
assert drop_eagle_block is False, (
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
@@ -874,7 +912,7 @@ class MambaManager(SingleTypeKVCacheManager):
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
drop_eagle_block: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
@@ -1168,7 +1206,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
drop_eagle_block: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,