[HMA] [KVEvent] Enable GPU-side KV events for HMA (#37688)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
This commit is contained in:
Martin Hickey
2026-04-12 08:01:02 +01:00
committed by GitHub
parent 17e787a779
commit cc07dad789
8 changed files with 300 additions and 28 deletions
@@ -43,10 +43,13 @@ class BlockStored(KVCacheEvent):
prompt embeddings data, etc. for that specific block.
"""
group_idx: int | None = None
class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash]
medium: str | None
group_idx: int | None = None
class AllBlocksCleared(KVCacheEvent):
+74
View File
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.distributed.kv_events import BlockRemoved, BlockStored
# Minimal ExternalBlockHash for testing (bytes are a valid ExternalBlockHash).
_FAKE_HASH: bytes = b"\xab" * 32
def _make_block_stored(group_idx: int | None = None) -> BlockStored:
return BlockStored(
block_hashes=[_FAKE_HASH],
parent_block_hash=None,
token_ids=[1, 2, 3, 4],
block_size=4,
lora_id=None,
medium="GPU",
lora_name=None,
group_idx=group_idx,
)
def _make_block_removed(group_idx: int | None = None) -> BlockRemoved:
return BlockRemoved(
block_hashes=[_FAKE_HASH],
medium="GPU",
group_idx=group_idx,
)
def test_block_stored_default_group_idx_is_none():
"""group_idx defaults to None when not provided."""
event = _make_block_stored()
assert event.group_idx is None
def test_block_removed_default_group_idx_is_none():
"""group_idx defaults to None when not provided."""
event = _make_block_removed()
assert event.group_idx is None
@pytest.mark.parametrize("group_idx", [1, 2, 3])
def test_block_stored_hash_differs_by_group_idx(group_idx: int):
"""BlockStored events that differ only in group_idx must hash differently."""
other_group_idx = group_idx + 1
event_a = _make_block_stored(group_idx=group_idx)
event_b = _make_block_stored(group_idx=other_group_idx)
assert hash(event_a) != hash(event_b)
def test_block_stored_hash_same_for_equal_group_idx():
"""Two BlockStored events with identical fields produce the same hash."""
event_a = _make_block_stored(group_idx=1)
event_b = _make_block_stored(group_idx=1)
assert hash(event_a) == hash(event_b)
@pytest.mark.parametrize("group_idx", [1, 2, 3])
def test_block_removed_hash_differs_by_group_idx(group_idx: int):
"""BlockRemoved events that differ only in group_idx must hash differently."""
other_group_idx = group_idx + 1
event_a = _make_block_removed(group_idx=group_idx)
event_b = _make_block_removed(group_idx=other_group_idx)
assert hash(event_a) != hash(event_b)
def test_block_removed_hash_same_for_equal_group_idx():
"""Two BlockRemoved events with identical fields produce the same hash."""
event_a = _make_block_removed(group_idx=1)
event_b = _make_block_removed(group_idx=1)
assert hash(event_a) == hash(event_b)
+28
View File
@@ -10,6 +10,7 @@ import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
@@ -2137,3 +2138,30 @@ def test_unify_hybrid_kv_cache_specs():
with pytest.raises(ValueError):
kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
def test_hma_not_disabled_when_kv_events_enabled():
"""
Test enabling KV events must not force disable_hybrid_kv_cache_manager to True.
This test guards against that regression by verifying that a VllmConfig
with kv_events_config set still resolves disable_hybrid_kv_cache_manager
to False (i.e. HMA remains enabled) when no other condition requires it
to be disabled.
"""
model_config = ModelConfig(max_model_len=16)
kv_events_config = KVEventsConfig(
enable_kv_cache_events=True,
publisher="null",
)
# Leave disable_hybrid_kv_cache_manager as None (the default) so that
# VllmConfig.__post_init__ resolves it automatically.
vllm_config = VllmConfig(
model_config=model_config,
kv_events_config=kv_events_config,
)
assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False, (
"kv_events_config must not force-disable the hybrid KV cache manager."
)
+150 -1
View File
@@ -1970,6 +1970,7 @@ def test_null_parent_block_hash():
block_size = 1
num_cached_blocks = 2
num_full_blocks = 4
kv_cache_group_id = 0
pool = BlockPool(
num_gpu_blocks=8,
@@ -2002,7 +2003,7 @@ def test_null_parent_block_hash():
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=block_size,
kv_cache_group_id=0,
kv_cache_group_id=kv_cache_group_id,
)
events = pool.take_events()
@@ -2021,6 +2022,7 @@ def test_null_parent_block_hash():
for h in req.block_hashes[num_cached_blocks:num_full_blocks]
]
assert event.block_hashes == expected_new_hashes
assert event.group_idx == kv_cache_group_id
# Ensure we didn't accidentally assign a hash to the null block.
assert pool.null_block.block_hash is None
@@ -2087,6 +2089,153 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
assert block_stored_event.block_size == block_size
@pytest.mark.parametrize("group_id", [0, 1, 2])
def test_block_stored_event_group_idx(group_id: int):
"""Test BlockStored events emitted by cache_full_blocks carry the correct
group_idx."""
block_size = 4
num_tokens = block_size * 2
pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
hash_block_size=block_size,
enable_kv_cache_events=True,
)
req = make_request(
"req_grp_idx",
prompt_token_ids=list(range(num_tokens)),
block_size=block_size,
hash_fn=sha256,
)
blocks = pool.get_new_blocks(2)
pool.cache_full_blocks(
request=req,
blocks=blocks,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
kv_cache_group_id=group_id,
)
events = pool.take_events()
assert len(events) == 1
assert isinstance(events[0], BlockStored)
assert events[0].group_idx == group_id
def test_block_stored_event_group_idx_multiple_groups():
"""
Test BlockStored events for separate HMA groups that each carry the
correct group_idx.
Simulates the HMA scenario where full-attention blocks (group 0) and
sliding-window blocks (group 1) are cached independently and must be
distinguishable by consumers doing HMA-aware prefix-cache routing.
"""
block_size = 4
num_tokens = block_size * 2
# null block + 4 usable (2 per group)
pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
hash_block_size=block_size,
enable_kv_cache_events=True,
)
req = make_request(
"req_multi_grp",
prompt_token_ids=list(range(num_tokens)),
block_size=block_size,
hash_fn=sha256,
)
# Cache blocks for group 0 (full-attention)
blocks_grp0 = pool.get_new_blocks(2)
pool.cache_full_blocks(
request=req,
blocks=blocks_grp0,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
kv_cache_group_id=0,
)
# Cache blocks for group 1 (sliding-window)
blocks_grp1 = pool.get_new_blocks(2)
pool.cache_full_blocks(
request=req,
blocks=blocks_grp1,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
kv_cache_group_id=1,
)
events = pool.take_events()
assert len(events) == 2
assert isinstance(events[0], BlockStored)
assert events[0].group_idx == 0
assert isinstance(events[1], BlockStored)
assert events[1].group_idx == 1
@pytest.mark.parametrize("group_id", [0, 1, 2])
def test_block_removed_event_group_idx(group_id: int):
"""
Test BlockRemoved events emitted on eviction carry the group_idx extracted
from the evicted block's BlockHashWithGroupId via get_group_id().
"""
block_size = 4
num_tokens = block_size * 2
# null block + 4 usable; allocate all 4, cache 2, free all, re-allocate
# all 4 so the 2 cached blocks are forced through _maybe_evict_cached_block.
pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
hash_block_size=block_size,
enable_kv_cache_events=True,
)
req = make_request(
"req_evict_grp",
prompt_token_ids=list(range(num_tokens)),
block_size=block_size,
hash_fn=sha256,
)
# Allocate all usable blocks and cache the first two for the target group.
all_blocks = pool.get_new_blocks(4)
pool.cache_full_blocks(
request=req,
blocks=all_blocks,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
kv_cache_group_id=group_id,
)
# Drain the BlockStored events so only eviction events remain later.
pool.take_events()
# Return all blocks to the free queue so they become eviction candidates.
pool.free_blocks(all_blocks)
# Re-allocate all blocks; the two with hashes trigger BlockRemoved events.
pool.get_new_blocks(4)
events = pool.take_events()
removed_events = [e for e in events if isinstance(e, BlockRemoved)]
assert len(removed_events) == 2
for event in removed_events:
assert event.group_idx == group_id
def test_eagle_enabled_removes_last_block():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""
+31 -19
View File
@@ -936,6 +936,13 @@ async def test_engine_core_client_future_utility_async(
client.shutdown()
@pytest.mark.parametrize(
"model_name,num_groups",
[
("meta-llama/Llama-3.2-1B-Instruct", 1),
("google/gemma-3-1b-it", 7),
],
)
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
@@ -944,12 +951,14 @@ async def test_engine_core_client_future_utility_async(
def test_kv_cache_events(
multiprocessing_mode: bool,
publisher_config,
model_name: str,
num_groups: int,
):
block_size = 16
num_blocks = 2
engine_args = EngineArgs(
model=MODEL_NAME,
model=model_name,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size,
@@ -985,26 +994,29 @@ def test_kv_cache_events(
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert len(received.events) == 1, "We should have exactly one BlockStored event"
event = received.events[0]
assert isinstance(event, BlockStored), "We should have a BlockStored event"
assert len(event.block_hashes) == num_blocks, (
"We should have a BlockStored event with 2 block_hashes"
)
assert event.block_size == block_size, (
"Block size should be the same as the block size"
)
assert event.parent_block_hash is None, "Parent block hash should be None"
assert event.lora_id is None, "Lora id should be None"
assert event.lora_name is None, "Lora name should be None"
assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens"
)
assert event.token_ids == custom_tokens, (
"Token ids should be the same as the custom tokens"
assert len(received.events) == num_groups, (
f"Expected {num_groups} BlockStored event(s), got {len(received.events)}"
)
for index, event in enumerate(received.events):
assert isinstance(event, BlockStored), "We should have a BlockStored event"
assert len(event.block_hashes) == num_blocks, (
"We should have a BlockStored event with 2 block_hashes"
)
assert event.block_size == block_size, (
"Block size should be the same as the block size"
)
assert event.parent_block_hash is None, "Parent block hash should be None"
assert event.lora_id is None, "Lora id should be None"
assert event.lora_name is None, "Lora name should be None"
assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens"
)
assert event.token_ids == custom_tokens, (
"Token ids should be the same as the custom tokens"
)
assert event.group_idx == index
finally:
client.shutdown()
subscriber.close()
-3
View File
@@ -1229,9 +1229,6 @@ class VllmConfig:
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
need_disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
need_disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
+11 -1
View File
@@ -67,6 +67,8 @@ class BlockStored(KVCacheEvent):
KV cache consumers to reconstruct block hashes.
"""
group_idx: int | None = None
def __hash__(self) -> int:
return hash(
(
@@ -77,6 +79,7 @@ class BlockStored(KVCacheEvent):
self.lora_id,
self.medium,
tuple(self.extra_keys) if self.extra_keys else None,
self.group_idx,
)
)
@@ -84,9 +87,16 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash]
medium: str | None
group_idx: int | None = None
def __hash__(self) -> int:
return hash((tuple(self.block_hashes), self.medium))
return hash(
(
tuple(self.block_hashes),
self.medium,
self.group_idx,
)
)
class AllBlocksCleared(KVCacheEvent):
+3 -4
View File
@@ -22,6 +22,7 @@ from vllm.v1.core.kv_cache_utils import (
KVCacheBlock,
generate_block_hash_extra_keys,
get_block_hash,
get_group_id,
make_block_hash_with_group_id,
maybe_convert_block_hash,
)
@@ -314,6 +315,7 @@ class BlockPool:
if request.lora_request
else None,
extra_keys=extra_keys_list if extra_keys_list else None,
group_idx=kv_cache_group_id,
)
)
@@ -377,14 +379,11 @@ class BlockPool:
block.reset_hash()
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(
block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))],
medium=MEDIUM_GPU,
group_idx=get_group_id(block_hash),
)
)
return True