mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] Cache the EAGLE/MTP lookahead block in the SWA prefix-cache mask (#44082)
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
This commit is contained in:
@@ -2466,6 +2466,201 @@ def test_eagle_with_sliding_window():
|
||||
assert num_tokens == 0
|
||||
|
||||
|
||||
def test_eagle_swa_alignment_caches_extra_block():
    """Regression: SWA + EAGLE with `sliding_window <= alignment_tokens`.

    When the cache-hit alignment (lcm of per-group block sizes) is larger than
    the SWA window, the SWA mask only kept the last block of each aligned
    segment. EAGLE/MTP lookup needs ``tail + 1`` contiguous cached blocks and
    that +1 block lives at the next segment's first position, which was left
    uncached. The fix caches that extra block when ``use_eagle=True``.
    """
    block_size = 8
    # Full group uses 4 * block_size, so lcm/alignment is 4 * block_size.
    # SWA group has sliding_window = block_size (i.e., tail = 1 block).
    # Without the fix, the second cached block needed for the EAGLE 2-block
    # match never exists -> EAGLE cache hit fails entirely.
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["full"],
                FullAttentionSpec(
                    block_size=4 * block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(
                ["swa_mtp"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=block_size,
                ),
                is_eagle_group=True,
            ),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # Prime the cache with a long prompt (16 swa blocks = 4 aligned segments).
    token_ids = [i for i in range(16) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    # Use the token count reported by the manager rather than deriving it from
    # ``blocks[0]``: group 0 is the full-attention group, whose blocks hold
    # 4 * block_size tokens each, so scaling its length by the SWA block size
    # would be wrong for any non-empty hit.
    computed_blocks, num_cached_tokens = manager.get_computed_blocks(req0)
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None
    manager.free(req0)

    # Second request with identical prompt should find an EAGLE cache hit.
    # Without the fix, ``num_computed_tokens`` is 0; with the fix, the hit is
    # non-empty and lands on an alignment boundary.
    req1 = make_request("1", token_ids, block_size, sha256)
    _, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens > 0, (
        "EAGLE + SWA with sliding_window <= alignment failed to find any "
        "cache hit; the +1 block past each segment boundary must be cached."
    )
    # Each aligned segment contributes 4 * block_size = 32 tokens. The EAGLE
    # lookup matches one block past a boundary and then drops it, so the
    # reported hit must be a whole number of segments.
    assert num_computed_tokens % (4 * block_size) == 0
|
||||
|
||||
|
||||
def test_eagle_swa_boundary_caches_post_boundary_block():
    """EAGLE + SWA must cache the first block after an alignment boundary.

    A 40-token computed prefix with 8-token SWA blocks and 32-token hybrid
    alignment needs SWA blocks 3 and 4 cached to reuse a 32-token prefix:
    block 3 is the segment tail, and block 4 is the EAGLE lookahead block
    that gets dropped after lookup.
    """
    block_size = 8
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["full"],
                FullAttentionSpec(
                    block_size=4 * block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(
                ["swa_mtp"],
                SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                    sliding_window=block_size,
                ),
                is_eagle_group=True,
            ),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 5 SWA blocks = 40 tokens: one full 32-token segment plus one extra block.
    token_ids = [i for i in range(5) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    # Use the token count reported by the manager rather than deriving it from
    # ``blocks[0]``: group 0 is the full-attention group, whose blocks hold
    # 4 * block_size tokens each, so scaling its length by the SWA block size
    # would be wrong for any non-empty hit.
    computed_blocks, num_cached_tokens = manager.get_computed_blocks(req0)
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # Both the segment tail (block 3) and the post-boundary lookahead block
    # (block 4) must be present in the prefix cache for the SWA group (id 1).
    pool = manager.block_pool
    assert pool.get_cached_block(req0.block_hashes[3], kv_cache_group_ids=[1])
    assert pool.get_cached_block(req0.block_hashes[4], kv_cache_group_ids=[1])
    manager.free(req0)

    # Same prefix plus one extra token: exactly one aligned segment is reused.
    req1 = make_request("1", token_ids + [999], block_size, sha256)
    _, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 4 * block_size
|
||||
|
||||
|
||||
def test_eagle_grouped_swa_siblings_use_same_cache_mask():
    """Grouped SWA siblings must cache the EAGLE lookahead block together.

    Two SWA groups share one spec, but only ``swa_mtp`` is flagged as an
    EAGLE group. After caching a 9-block prompt, the post-boundary blocks
    (4 and 8) must be cached for both SWA groups (ids 1 and 2), and a
    follow-up request must reuse two aligned segments (8 * block_size tokens).
    """
    block_size = 8
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["full"],
                FullAttentionSpec(
                    block_size=4 * block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float16,
                ),
            ),
            KVCacheGroupSpec(["swa_main"], swa_spec),
            KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 9 SWA blocks = 72 tokens: two full 32-token segments plus one block.
    token_ids = [i for i in range(9) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    # Use the token count reported by the manager rather than deriving it from
    # ``blocks[0]``: group 0 is the full-attention group, whose blocks hold
    # 4 * block_size tokens each, so scaling its length by the SWA block size
    # would be wrong for any non-empty hit.
    computed_blocks, num_cached_tokens = manager.get_computed_blocks(req0)
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # The first block after each alignment boundary must be cached for both
    # SWA siblings, not just for the EAGLE-flagged one.
    pool = manager.block_pool
    assert pool.get_cached_block(req0.block_hashes[4], kv_cache_group_ids=[1, 2])
    assert pool.get_cached_block(req0.block_hashes[8], kv_cache_group_ids=[1, 2])
    manager.free(req0)

    req1 = make_request("1", token_ids + [999], block_size, sha256)
    _, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 8 * block_size
|
||||
|
||||
|
||||
def test_different_block_size():
|
||||
block_size = 16
|
||||
# full attention and sliding window attention layers have the same page size:
|
||||
@@ -2614,7 +2809,7 @@ def test_hybrid_cache_blocks_swa_tail_window_only():
|
||||
|
||||
|
||||
def test_hybrid_cache_blocks_clamped_to_lcm():
|
||||
"""HybridKVCacheCoordinator.cache_blocks() clamps to lcm_block_size.
|
||||
"""HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size.
|
||||
Chunks past the last lcm-aligned boundary can never participate in a
|
||||
cache hit (find_longest_cache_hit always returns lcm-aligned hits), so
|
||||
caching them only pollutes the prefix-cache hash map and keeps blocks
|
||||
|
||||
@@ -86,7 +86,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=block_pool,
|
||||
kv_cache_spec=chunked_local_attention_spec,
|
||||
use_eagle=False,
|
||||
drop_eagle_block=False,
|
||||
alignment_tokens=block_size,
|
||||
)[0]
|
||||
assert len(computed_blocks) == expect_length
|
||||
@@ -157,7 +157,7 @@ def test_sliding_window_possible_cached_prefix():
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=block_pool,
|
||||
kv_cache_spec=sliding_window_spec,
|
||||
use_eagle=False,
|
||||
drop_eagle_block=False,
|
||||
alignment_tokens=block_size,
|
||||
)[0]
|
||||
assert len(computed_blocks) == expect_length
|
||||
|
||||
@@ -233,7 +233,7 @@ class MooncakeStoreCoordinator:
|
||||
kv_cache_group_ids=group_ids,
|
||||
block_pool=cast(BlockPool, cached_block_pool),
|
||||
kv_cache_spec=spec,
|
||||
use_eagle=(0 in eagle_indices),
|
||||
drop_eagle_block=(0 in eagle_indices),
|
||||
alignment_tokens=spec.block_size,
|
||||
)
|
||||
num_groups = len(self.kv_cache_groups)
|
||||
@@ -262,9 +262,9 @@ class MooncakeStoreCoordinator:
|
||||
)
|
||||
continue
|
||||
|
||||
use_eagle = idx in eagle_indices and idx not in eagle_verified
|
||||
drop_eagle_block = idx in eagle_indices and idx not in eagle_verified
|
||||
_max_length = curr_hit_length
|
||||
if use_eagle:
|
||||
if drop_eagle_block:
|
||||
_max_length = min(curr_hit_length + spec.block_size, max_length)
|
||||
hashes = self.block_hashes_for_spec(block_hashes, spec)
|
||||
hit_blocks = manager_cls.find_longest_cache_hit(
|
||||
@@ -273,11 +273,11 @@ class MooncakeStoreCoordinator:
|
||||
kv_cache_group_ids=group_ids,
|
||||
block_pool=cast(BlockPool, cached_block_pool),
|
||||
kv_cache_spec=spec,
|
||||
use_eagle=use_eagle,
|
||||
drop_eagle_block=drop_eagle_block,
|
||||
alignment_tokens=self.lcm_block_size,
|
||||
)
|
||||
_new_hit_length = len(hit_blocks[0]) * spec.block_size
|
||||
if use_eagle:
|
||||
if drop_eagle_block:
|
||||
eagle_verified.add(idx)
|
||||
elif _new_hit_length < curr_hit_length:
|
||||
eagle_verified.clear()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import NamedTuple
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
|
||||
@@ -381,6 +382,8 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
||||
)
|
||||
# Single group: the flag has no effect here, but set ``use_eagle`` anyway for consistency.
|
||||
self.single_type_managers[0].use_eagle = 0 in self.eagle_group_ids
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self,
|
||||
@@ -393,7 +396,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=0 in self.eagle_group_ids,
|
||||
drop_eagle_block=0 in self.eagle_group_ids,
|
||||
alignment_tokens=self.block_size,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
@@ -401,6 +404,21 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
|
||||
class SpecGroup(NamedTuple):
    """A batch of KV cache groups that have an identical spec.

    All members of a batch are cached and looked up jointly, in a single
    cache-hit lookup. Because of that, the EAGLE last-block drop cannot be
    decided per member: ``use_eagle`` is True as soon as any member group is
    an EAGLE/MTP group, and it then applies to the whole batch.
    """

    # KV cache spec common to every member group.
    spec: KVCacheSpec
    # Ids of the member KV cache groups.
    group_ids: list[int]
    # Manager class used for the members.
    manager_cls: type[SingleTypeKVCacheManager]
    # True iff at least one member group is an EAGLE/MTP group.
    use_eagle: bool
|
||||
|
||||
|
||||
class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
"""
|
||||
KV cache coordinator for hybrid models with multiple KV cache types, and
|
||||
@@ -452,58 +470,63 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
Groups KV cache groups by their spec type for efficient batch processing
|
||||
during cache hit lookup.
|
||||
"""
|
||||
attention_groups: list[
|
||||
tuple[KVCacheSpec, list[int], type[SingleTypeKVCacheManager]]
|
||||
] = []
|
||||
|
||||
self.attention_groups: list[SpecGroup] = []
|
||||
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
manager_cls = self.single_type_managers[i].__class__
|
||||
spec = g.kv_cache_spec
|
||||
use_eagle = i in self.eagle_group_ids
|
||||
|
||||
# Try to find an existing group with the same spec
|
||||
for existing_spec, group_ids, existing_cls in attention_groups:
|
||||
if existing_spec == spec:
|
||||
assert manager_cls is existing_cls, (
|
||||
for idx, group in enumerate(self.attention_groups):
|
||||
if group.spec == spec:
|
||||
assert manager_cls is group.manager_cls, (
|
||||
"Expected same manager class for identical KV cache specs."
|
||||
)
|
||||
group_ids.append(i)
|
||||
group.group_ids.append(i)
|
||||
if use_eagle and not group.use_eagle:
|
||||
self.attention_groups[idx] = group._replace(use_eagle=True)
|
||||
break
|
||||
else:
|
||||
attention_groups.append((spec, [i], manager_cls))
|
||||
self.attention_groups.append(
|
||||
SpecGroup(spec, [i], manager_cls, use_eagle)
|
||||
)
|
||||
|
||||
assert len(attention_groups) > 1, (
|
||||
assert len(self.attention_groups) > 1, (
|
||||
"HybridKVCacheCoordinator requires at least two attention groups."
|
||||
)
|
||||
|
||||
# Put full attention first: its efficient left-to-right scan provides
|
||||
# a tighter initial bound, reducing work for subsequent groups.
|
||||
self.attention_groups = sorted(
|
||||
attention_groups,
|
||||
key=lambda x: not isinstance(x[0], FullAttentionSpec),
|
||||
self.attention_groups.sort(
|
||||
key=lambda g: not isinstance(g.spec, FullAttentionSpec)
|
||||
)
|
||||
|
||||
# Attention-group indices (into ``self.attention_groups``) that
|
||||
# contain at least one EAGLE/MTP KV cache group.
|
||||
self.eagle_attn_group_indices: set[int] = {
|
||||
i
|
||||
for i, (_, group_ids, _) in enumerate(self.attention_groups)
|
||||
if any(gid in self.eagle_group_ids for gid in group_ids)
|
||||
}
|
||||
# Propagate the eagle bit to each manager (default to ``use_eagle=False``).
|
||||
for group in self.attention_groups:
|
||||
if group.use_eagle:
|
||||
for gid in group.group_ids:
|
||||
self.single_type_managers[gid].use_eagle = True
|
||||
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
# Cache hits in this coordinator are always a multiple of
|
||||
# ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``).
|
||||
# Within an aligned region, SWA groups only consult a subset of blocks
|
||||
# Within an aligned region, SWA groups may only consult a subset of blocks
|
||||
# per ``scheduler_block_size``-segment so the unused blocks also stay
|
||||
# out of the prefix-cache hash map.
|
||||
num_computed_tokens = (
|
||||
aligned_num_computed_tokens = (
|
||||
num_computed_tokens // self.scheduler_block_size * self.scheduler_block_size
|
||||
)
|
||||
for manager in self.single_type_managers:
|
||||
num_tokens_to_cache = aligned_num_computed_tokens
|
||||
# EAGLE groups match one block past each aligned boundary and drop
|
||||
# it, so make that lookahead block eligible to be cached.
|
||||
if manager.use_eagle and aligned_num_computed_tokens > 0:
|
||||
num_tokens_to_cache = min(
|
||||
num_computed_tokens,
|
||||
aligned_num_computed_tokens + manager.block_size,
|
||||
)
|
||||
manager.cache_blocks(
|
||||
request,
|
||||
num_computed_tokens,
|
||||
alignment_tokens=self.scheduler_block_size,
|
||||
request, num_tokens_to_cache, alignment_tokens=self.scheduler_block_size
|
||||
)
|
||||
|
||||
def find_longest_cache_hit(
|
||||
@@ -543,7 +566,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
|
||||
# Full attn is always first if it exists.
|
||||
is_simple_hybrid = len(self.attention_groups) == 2 and isinstance(
|
||||
self.attention_groups[0][0], FullAttentionSpec
|
||||
self.attention_groups[0].spec, FullAttentionSpec
|
||||
)
|
||||
|
||||
# Attention-group indices whose EAGLE drop is verified at the current
|
||||
@@ -554,7 +577,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
while True:
|
||||
curr_hit_length = hit_length
|
||||
|
||||
for idx, (spec, group_ids, manager_cls) in enumerate(self.attention_groups):
|
||||
for idx, (spec, group_ids, manager_cls, use_eagle) in enumerate(
|
||||
self.attention_groups
|
||||
):
|
||||
cached_blocks = hit_blocks_by_group[group_ids[0]]
|
||||
if isinstance(spec, FullAttentionSpec) and cached_blocks is not None:
|
||||
# Full attention is downward-closed: we only need to look
|
||||
@@ -565,12 +590,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
)
|
||||
continue
|
||||
|
||||
use_eagle = (
|
||||
idx in self.eagle_attn_group_indices and idx not in eagle_verified
|
||||
)
|
||||
drop_eagle_block = use_eagle and idx not in eagle_verified
|
||||
|
||||
_max_length = curr_hit_length
|
||||
if use_eagle:
|
||||
if drop_eagle_block:
|
||||
# Eagle needs to match one more block and then pop the last.
|
||||
_max_length = min(
|
||||
curr_hit_length + spec.block_size, max_cache_hit_length
|
||||
@@ -581,11 +604,11 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
kv_cache_group_ids=group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=spec,
|
||||
use_eagle=use_eagle,
|
||||
drop_eagle_block=drop_eagle_block,
|
||||
alignment_tokens=self.scheduler_block_size,
|
||||
)
|
||||
_new_hit_length = len(hit_blocks[0]) * spec.block_size
|
||||
if use_eagle:
|
||||
if drop_eagle_block:
|
||||
eagle_verified.add(idx)
|
||||
elif _new_hit_length < curr_hit_length:
|
||||
# length shrunk; invalidate previous eagle verifications
|
||||
@@ -601,10 +624,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
break
|
||||
|
||||
# Truncate full attention blocks to final hit_length (if present)
|
||||
spec, group_ids, _ = self.attention_groups[0]
|
||||
if isinstance(spec, FullAttentionSpec):
|
||||
num_blocks = hit_length // spec.block_size
|
||||
for group_id in group_ids:
|
||||
first_group = self.attention_groups[0]
|
||||
if isinstance(first_group.spec, FullAttentionSpec):
|
||||
num_blocks = hit_length // first_group.spec.block_size
|
||||
for group_id in first_group.group_ids:
|
||||
if (blks := hit_blocks_by_group[group_id]) is not None:
|
||||
del blks[num_blocks:]
|
||||
|
||||
|
||||
@@ -86,6 +86,12 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self.kv_cache_group_id = kv_cache_group_id
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
# Whether this group's prefix-cache hits drop the EAGLE/MTP lookahead
|
||||
# block. Only consulted by managers whose hit logic is sparse within an
|
||||
# aligned segment (SWA). Defaults to False; the coordinator overrides it
|
||||
# after determining the attention groups.
|
||||
self.use_eagle = False
|
||||
|
||||
@classmethod
|
||||
def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]):
|
||||
return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks)
|
||||
@@ -317,8 +323,12 @@ class SingleTypeKVCacheManager(ABC):
|
||||
if alignment_tokens is None or alignment_tokens <= self.block_size:
|
||||
block_mask = None
|
||||
else:
|
||||
block_mask = self._cache_block_mask(
|
||||
num_cached_blocks, num_full_blocks, alignment_tokens
|
||||
block_mask = self.reachable_block_mask(
|
||||
num_cached_blocks,
|
||||
num_full_blocks,
|
||||
alignment_tokens,
|
||||
self.kv_cache_spec,
|
||||
self.use_eagle,
|
||||
)
|
||||
self.block_pool.cache_full_blocks(
|
||||
request=request,
|
||||
@@ -332,11 +342,14 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
self.num_cached_block[request.request_id] = num_full_blocks
|
||||
|
||||
def _cache_block_mask(
|
||||
self,
|
||||
num_cached_blocks: int,
|
||||
num_full_blocks: int,
|
||||
@classmethod
|
||||
def reachable_block_mask(
|
||||
cls,
|
||||
start_block: int,
|
||||
num_blocks: int,
|
||||
alignment_tokens: int,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[bool] | None:
|
||||
"""Per-block mask for ``cache_full_blocks``. ``None`` means cache
|
||||
every (non-null) block — the default for full attention.
|
||||
@@ -389,7 +402,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
drop_eagle_block: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
@@ -409,7 +422,10 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_group_ids: The ids of the kv cache groups.
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP.
|
||||
Always False for non-EAGLE/MTP groups, but can be False for EAGLE/MTP
|
||||
groups too if the last block is already dropped (e.g., in a
|
||||
convergence loop in `find_longest_cache_hit`).
|
||||
alignment_tokens: The returned cache hit length (in tokens) should
|
||||
be a multiple of this value (in tokens). By default, it should
|
||||
be set to the block_size.
|
||||
@@ -499,7 +515,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
drop_eagle_block: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
@@ -528,7 +544,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
computed.append(cached)
|
||||
else:
|
||||
break
|
||||
if use_eagle and computed_blocks[0]:
|
||||
if drop_eagle_block and computed_blocks[0]:
|
||||
# Need to drop the last matched block if eagle is enabled.
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
@@ -556,6 +572,19 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
super().__init__(kv_cache_spec, **kwargs)
|
||||
self.sliding_window = kv_cache_spec.sliding_window
|
||||
|
||||
@classmethod
def _contiguous_blocks_for_hit(
    cls, window_size: int, block_size: int, use_eagle: bool
) -> int:
    """Return how many contiguous cached blocks a prefix-cache hit needs.

    Args:
        window_size: Sliding window size, in tokens.
        block_size: Block size, in tokens.
        use_eagle: Whether one extra block must be matched so that the
            last matched block can be dropped for EAGLE.

    Returns:
        ``cdiv(window_size - 1, block_size)``, plus one when ``use_eagle``
        is set.
    """
    window_blocks = cdiv(window_size - 1, block_size)
    # With eagle enabled the last matched block has to be dropped. For a
    # sliding window layer this is done by requiring one more contiguous
    # block for the hit and then dropping the last matched block.
    return window_blocks + 1 if use_eagle else window_blocks
|
||||
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
@@ -564,7 +593,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
drop_eagle_block: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
@@ -575,17 +604,10 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
assert dcp_world_size == 1, "DCP not support sliding window attn now."
|
||||
assert pcp_world_size == 1, "PCP not support sliding window attn now."
|
||||
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
sliding_window_contiguous_blocks = cdiv(
|
||||
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size
|
||||
# The number of contiguous blocks needed for a prefix cache hit.
|
||||
sliding_window_contiguous_blocks = cls._contiguous_blocks_for_hit(
|
||||
kv_cache_spec.sliding_window, kv_cache_spec.block_size, drop_eagle_block
|
||||
)
|
||||
if use_eagle:
|
||||
# Need to drop the last matched block if eagle is enabled. For
|
||||
# sliding window layer, we achieve this by increasing the number of
|
||||
# contiguous blocks needed for prefix cache hit by one and dropping
|
||||
# the last matched block.
|
||||
sliding_window_contiguous_blocks += 1
|
||||
|
||||
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
||||
# optimize the time complexity from O(max_num_blocks) to
|
||||
@@ -608,7 +630,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
# Skip prefix matching check if the block is not aligned with
|
||||
# `alignment_tokens`.
|
||||
if num_contiguous_blocks == 0 and block_size != alignment_tokens:
|
||||
post_pop_blocks = i if use_eagle else i + 1
|
||||
post_pop_blocks = i if drop_eagle_block else i + 1
|
||||
if (post_pop_blocks * block_size) % alignment_tokens != 0:
|
||||
continue
|
||||
# Add the cached block to the computed blocks.
|
||||
@@ -636,7 +658,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
):
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
if use_eagle and computed_blocks[0]:
|
||||
if drop_eagle_block and computed_blocks[0]:
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
# Re-align after eagle pop: the pop may break the alignment
|
||||
@@ -650,17 +672,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
computed.pop()
|
||||
return computed_blocks
|
||||
|
||||
def _cache_block_mask(
|
||||
self, num_cached_blocks: int, num_full_blocks: int, alignment_tokens: int
|
||||
@classmethod
|
||||
def reachable_block_mask(
|
||||
cls,
|
||||
start_block: int,
|
||||
num_blocks: int,
|
||||
alignment_tokens: int,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[bool] | None:
|
||||
assert alignment_tokens > self.block_size
|
||||
per_segment = alignment_tokens // self.block_size
|
||||
tail = cdiv(self.sliding_window - 1, self.block_size)
|
||||
if tail >= per_segment:
|
||||
assert alignment_tokens > kv_cache_spec.block_size
|
||||
assert isinstance(kv_cache_spec, SlidingWindowSpec)
|
||||
per_segment = alignment_tokens // kv_cache_spec.block_size
|
||||
need = cls._contiguous_blocks_for_hit(
|
||||
window_size=kv_cache_spec.sliding_window,
|
||||
block_size=kv_cache_spec.block_size,
|
||||
use_eagle=use_eagle,
|
||||
)
|
||||
if need >= per_segment:
|
||||
return None
|
||||
skip = per_segment - tail
|
||||
# The matched run's right edge sits on the aligned boundary block when
|
||||
# EAGLE peeks one block past it (shift=1), otherwise on the last block
|
||||
# before the boundary (shift=0). A block is reachable iff it falls in
|
||||
# the ``need``-wide run ending at some boundary's right edge.
|
||||
shift = 1 if use_eagle else 0
|
||||
return [
|
||||
i % per_segment >= skip for i in range(num_cached_blocks, num_full_blocks)
|
||||
i >= shift and (i - shift) % per_segment >= per_segment - need
|
||||
for i in range(start_block, num_blocks)
|
||||
]
|
||||
|
||||
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
|
||||
@@ -714,7 +752,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
drop_eagle_block: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
@@ -745,7 +783,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_group_ids: The ids of the kv cache groups.
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP.
|
||||
dcp_world_size: The world size of decode context parallelism.
|
||||
pcp_world_size: The world size of prefill context parallelism.
|
||||
alignment_tokens: The returned cache hit length (in tokens) should
|
||||
@@ -758,7 +796,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
"ChunkedLocalAttentionManager can only be used for "
|
||||
"chunked local attention groups"
|
||||
)
|
||||
assert use_eagle is False, (
|
||||
assert drop_eagle_block is False, (
|
||||
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||
@@ -874,7 +912,7 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
drop_eagle_block: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
@@ -1168,7 +1206,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
drop_eagle_block: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
|
||||
Reference in New Issue
Block a user