mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[HMA] [KVEvent] Enable GPU-side KV events for HMA (#37688)
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com> Co-authored-by: Or Ozeri <or@ozery.com>
This commit is contained in:
@@ -43,10 +43,13 @@ class BlockStored(KVCacheEvent):
|
||||
prompt embeddings data, etc. for that specific block.
|
||||
"""
|
||||
|
||||
group_idx: int | None = None
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
    # Hashes of the blocks that were removed from the cache.
    block_hashes: list[ExternalBlockHash]
    # Storage medium the blocks were removed from; may be None.
    medium: str | None
    # Index of the KV cache group the removed blocks belonged to; None when
    # the producer does not report a group.
    group_idx: int | None = None
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
||||
|
||||
# Minimal ExternalBlockHash for testing (bytes are a valid ExternalBlockHash).
|
||||
_FAKE_HASH: bytes = b"\xab" * 32
|
||||
|
||||
|
||||
def _make_block_stored(group_idx: int | None = None) -> BlockStored:
    """Build a one-block GPU ``BlockStored`` event for the given group.

    Every field other than ``group_idx`` is fixed, so two events produced by
    this helper can differ only in their group index.
    """
    fields = dict(
        block_hashes=[_FAKE_HASH],
        parent_block_hash=None,
        token_ids=[1, 2, 3, 4],
        block_size=4,
        lora_id=None,
        medium="GPU",
        lora_name=None,
        group_idx=group_idx,
    )
    return BlockStored(**fields)
|
||||
|
||||
|
||||
def _make_block_removed(group_idx: int | None = None) -> BlockRemoved:
    """Build a one-block GPU ``BlockRemoved`` event for the given group.

    Only ``group_idx`` varies between events produced by this helper.
    """
    fields = dict(
        block_hashes=[_FAKE_HASH],
        medium="GPU",
        group_idx=group_idx,
    )
    return BlockRemoved(**fields)
|
||||
|
||||
|
||||
def test_block_stored_default_group_idx_is_none():
    """group_idx defaults to None when not provided."""
    # Construct the event directly and leave group_idx out entirely.
    # _make_block_stored always forwards group_idx as an explicit keyword,
    # so going through it would never exercise the field's default value.
    event = BlockStored(
        block_hashes=[_FAKE_HASH],
        parent_block_hash=None,
        token_ids=[1, 2, 3, 4],
        block_size=4,
        lora_id=None,
        medium="GPU",
        lora_name=None,
    )
    assert event.group_idx is None
|
||||
|
||||
|
||||
def test_block_removed_default_group_idx_is_none():
    """group_idx defaults to None when not provided."""
    # Construct the event directly and leave group_idx out entirely.
    # _make_block_removed always forwards group_idx as an explicit keyword,
    # so going through it would never exercise the field's default value.
    event = BlockRemoved(
        block_hashes=[_FAKE_HASH],
        medium="GPU",
    )
    assert event.group_idx is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_idx", [1, 2, 3])
def test_block_stored_hash_differs_by_group_idx(group_idx: int):
    """Changing only group_idx must change the hash of a BlockStored event."""
    # Compare each group index against its immediate successor.
    first = _make_block_stored(group_idx=group_idx)
    second = _make_block_stored(group_idx=group_idx + 1)
    assert hash(first) != hash(second)
|
||||
|
||||
|
||||
def test_block_stored_hash_same_for_equal_group_idx():
    """Independently built but field-identical BlockStored events hash alike."""
    first_hash = hash(_make_block_stored(group_idx=1))
    second_hash = hash(_make_block_stored(group_idx=1))
    assert first_hash == second_hash
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_idx", [1, 2, 3])
def test_block_removed_hash_differs_by_group_idx(group_idx: int):
    """Changing only group_idx must change the hash of a BlockRemoved event."""
    # Compare each group index against its immediate successor.
    first = _make_block_removed(group_idx=group_idx)
    second = _make_block_removed(group_idx=group_idx + 1)
    assert hash(first) != hash(second)
|
||||
|
||||
|
||||
def test_block_removed_hash_same_for_equal_group_idx():
    """Independently built but field-identical BlockRemoved events hash alike."""
    first_hash = hash(_make_block_removed(group_idx=1))
    second_hash = hash(_make_block_removed(group_idx=1))
    assert first_hash == second_hash
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
|
||||
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFeatureSpec,
|
||||
@@ -2137,3 +2138,30 @@ def test_unify_hybrid_kv_cache_specs():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
|
||||
|
||||
def test_hma_not_disabled_when_kv_events_enabled():
    """
    Enabling KV events must not force disable_hybrid_kv_cache_manager to True.

    A VllmConfig built with a kv_events_config, and with
    disable_hybrid_kv_cache_manager left at its default, must still resolve
    the flag to False (i.e. HMA stays enabled) when nothing else requires
    the hybrid KV cache manager to be disabled.
    """
    # disable_hybrid_kv_cache_manager is deliberately not passed, so that
    # VllmConfig.__post_init__ resolves it automatically.
    resolved_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16),
        kv_events_config=KVEventsConfig(
            enable_kv_cache_events=True,
            publisher="null",
        ),
    )

    hma_disabled = resolved_config.scheduler_config.disable_hybrid_kv_cache_manager
    assert hma_disabled is False, (
        "kv_events_config must not force-disable the hybrid KV cache manager."
    )
|
||||
|
||||
@@ -1970,6 +1970,7 @@ def test_null_parent_block_hash():
|
||||
block_size = 1
|
||||
num_cached_blocks = 2
|
||||
num_full_blocks = 4
|
||||
kv_cache_group_id = 0
|
||||
|
||||
pool = BlockPool(
|
||||
num_gpu_blocks=8,
|
||||
@@ -2002,7 +2003,7 @@ def test_null_parent_block_hash():
|
||||
num_cached_blocks=num_cached_blocks,
|
||||
num_full_blocks=num_full_blocks,
|
||||
block_size=block_size,
|
||||
kv_cache_group_id=0,
|
||||
kv_cache_group_id=kv_cache_group_id,
|
||||
)
|
||||
|
||||
events = pool.take_events()
|
||||
@@ -2021,6 +2022,7 @@ def test_null_parent_block_hash():
|
||||
for h in req.block_hashes[num_cached_blocks:num_full_blocks]
|
||||
]
|
||||
assert event.block_hashes == expected_new_hashes
|
||||
assert event.group_idx == kv_cache_group_id
|
||||
|
||||
# Ensure we didn't accidentally assign a hash to the null block.
|
||||
assert pool.null_block.block_hash is None
|
||||
@@ -2087,6 +2089,153 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
|
||||
assert block_stored_event.block_size == block_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_id", [0, 1, 2])
def test_block_stored_event_group_idx(group_id: int):
    """BlockStored events emitted by cache_full_blocks must report the
    kv_cache_group_id they were cached under as their group_idx."""
    tokens_per_block = 4
    full_block_count = 2

    block_pool = BlockPool(
        num_gpu_blocks=5,
        enable_caching=True,
        hash_block_size=tokens_per_block,
        enable_kv_cache_events=True,
    )

    # The prompt fills exactly `full_block_count` blocks.
    request = make_request(
        "req_grp_idx",
        prompt_token_ids=list(range(tokens_per_block * full_block_count)),
        block_size=tokens_per_block,
        hash_fn=sha256,
    )

    block_pool.cache_full_blocks(
        request=request,
        blocks=block_pool.get_new_blocks(full_block_count),
        num_cached_blocks=0,
        num_full_blocks=full_block_count,
        block_size=tokens_per_block,
        kv_cache_group_id=group_id,
    )

    emitted = block_pool.take_events()
    assert len(emitted) == 1
    stored_event = emitted[0]
    assert isinstance(stored_event, BlockStored)
    assert stored_event.group_idx == group_id
|
||||
|
||||
|
||||
def test_block_stored_event_group_idx_multiple_groups():
    """
    BlockStored events cached for different HMA groups must each report
    their own group_idx.

    Mirrors the HMA case in which full-attention blocks (group 0) and
    sliding-window blocks (group 1) are cached independently, so consumers
    doing HMA-aware prefix-cache routing need to tell them apart.
    """
    tokens_per_block = 4
    blocks_per_group = 2
    # Group 0 stands for full attention, group 1 for sliding window.
    group_ids = (0, 1)

    # null block + 4 usable (2 per group)
    block_pool = BlockPool(
        num_gpu_blocks=5,
        enable_caching=True,
        hash_block_size=tokens_per_block,
        enable_kv_cache_events=True,
    )

    request = make_request(
        "req_multi_grp",
        prompt_token_ids=list(range(tokens_per_block * blocks_per_group)),
        block_size=tokens_per_block,
        hash_fn=sha256,
    )

    # Allocate and cache each group's blocks in turn, group 0 first.
    for gid in group_ids:
        group_blocks = block_pool.get_new_blocks(blocks_per_group)
        block_pool.cache_full_blocks(
            request=request,
            blocks=group_blocks,
            num_cached_blocks=0,
            num_full_blocks=blocks_per_group,
            block_size=tokens_per_block,
            kv_cache_group_id=gid,
        )

    emitted = block_pool.take_events()
    assert len(emitted) == 2
    # Events are expected in caching order, one per group.
    for expected_gid, stored_event in zip(group_ids, emitted):
        assert isinstance(stored_event, BlockStored)
        assert stored_event.group_idx == expected_gid
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_id", [0, 1, 2])
def test_block_removed_event_group_idx(group_id: int):
    """
    BlockRemoved events emitted on eviction must carry the group_idx taken
    from the evicted block's BlockHashWithGroupId via get_group_id().
    """
    tokens_per_block = 4
    usable_block_count = 4
    cached_block_count = 2

    # null block + 4 usable. Plan: allocate all 4, cache 2, free all, then
    # re-allocate all 4 so that the 2 cached blocks are forced through
    # _maybe_evict_cached_block.
    block_pool = BlockPool(
        num_gpu_blocks=5,
        enable_caching=True,
        hash_block_size=tokens_per_block,
        enable_kv_cache_events=True,
    )

    request = make_request(
        "req_evict_grp",
        prompt_token_ids=list(range(tokens_per_block * cached_block_count)),
        block_size=tokens_per_block,
        hash_fn=sha256,
    )

    # Take every usable block and cache the first two under the target group.
    held_blocks = block_pool.get_new_blocks(usable_block_count)
    block_pool.cache_full_blocks(
        request=request,
        blocks=held_blocks,
        num_cached_blocks=0,
        num_full_blocks=cached_block_count,
        block_size=tokens_per_block,
        kv_cache_group_id=group_id,
    )

    # Discard the BlockStored events so only eviction events are seen below.
    block_pool.take_events()

    # Freed blocks go back to the free queue and become eviction candidates.
    block_pool.free_blocks(held_blocks)

    # Re-allocating everything evicts the two blocks that carry hashes.
    block_pool.get_new_blocks(usable_block_count)

    evictions = [
        evt for evt in block_pool.take_events() if isinstance(evt, BlockRemoved)
    ]

    assert len(evictions) == 2
    assert all(evt.group_idx == group_id for evt in evictions)
|
||||
|
||||
|
||||
def test_eagle_enabled_removes_last_block():
|
||||
"""Verify Eagle does NOT remove blocks when request
|
||||
length is divisible by block size."""
|
||||
|
||||
@@ -936,6 +936,13 @@ async def test_engine_core_client_future_utility_async(
|
||||
client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,num_groups",
|
||||
[
|
||||
("meta-llama/Llama-3.2-1B-Instruct", 1),
|
||||
("google/gemma-3-1b-it", 7),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"multiprocessing_mode,publisher_config",
|
||||
[(True, "tcp"), (False, "inproc")],
|
||||
@@ -944,12 +951,14 @@ async def test_engine_core_client_future_utility_async(
|
||||
def test_kv_cache_events(
|
||||
multiprocessing_mode: bool,
|
||||
publisher_config,
|
||||
model_name: str,
|
||||
num_groups: int,
|
||||
):
|
||||
block_size = 16
|
||||
num_blocks = 2
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
block_size=block_size,
|
||||
@@ -985,26 +994,29 @@ def test_kv_cache_events(
|
||||
assert result is not None, "No message received"
|
||||
|
||||
seq, received = result
|
||||
|
||||
assert seq == 0, "Sequence number mismatch"
|
||||
assert len(received.events) == 1, "We should have exactly one BlockStored event"
|
||||
event = received.events[0]
|
||||
assert isinstance(event, BlockStored), "We should have a BlockStored event"
|
||||
assert len(event.block_hashes) == num_blocks, (
|
||||
"We should have a BlockStored event with 2 block_hashes"
|
||||
)
|
||||
assert event.block_size == block_size, (
|
||||
"Block size should be the same as the block size"
|
||||
)
|
||||
assert event.parent_block_hash is None, "Parent block hash should be None"
|
||||
assert event.lora_id is None, "Lora id should be None"
|
||||
assert event.lora_name is None, "Lora name should be None"
|
||||
assert len(event.token_ids) == num_blocks * block_size, (
|
||||
"Token ids should be the same as the custom tokens"
|
||||
)
|
||||
assert event.token_ids == custom_tokens, (
|
||||
"Token ids should be the same as the custom tokens"
|
||||
assert len(received.events) == num_groups, (
|
||||
f"Expected {num_groups} BlockStored event(s), got {len(received.events)}"
|
||||
)
|
||||
|
||||
for index, event in enumerate(received.events):
|
||||
assert isinstance(event, BlockStored), "We should have a BlockStored event"
|
||||
assert len(event.block_hashes) == num_blocks, (
|
||||
"We should have a BlockStored event with 2 block_hashes"
|
||||
)
|
||||
assert event.block_size == block_size, (
|
||||
"Block size should be the same as the block size"
|
||||
)
|
||||
assert event.parent_block_hash is None, "Parent block hash should be None"
|
||||
assert event.lora_id is None, "Lora id should be None"
|
||||
assert event.lora_name is None, "Lora name should be None"
|
||||
assert len(event.token_ids) == num_blocks * block_size, (
|
||||
"Token ids should be the same as the custom tokens"
|
||||
)
|
||||
assert event.token_ids == custom_tokens, (
|
||||
"Token ids should be the same as the custom tokens"
|
||||
)
|
||||
assert event.group_idx == index
|
||||
finally:
|
||||
client.shutdown()
|
||||
subscriber.close()
|
||||
|
||||
@@ -1229,9 +1229,6 @@ class VllmConfig:
|
||||
if not current_platform.support_hybrid_kv_cache():
|
||||
# Hybrid KV cache manager is not supported on non-GPU platforms.
|
||||
need_disable_hybrid_kv_cache_manager = True
|
||||
if self.kv_events_config is not None:
|
||||
# Hybrid KV cache manager is not compatible with KV events.
|
||||
need_disable_hybrid_kv_cache_manager = True
|
||||
if (
|
||||
self.model_config is not None
|
||||
and self.model_config.attention_chunk_size is not None
|
||||
|
||||
@@ -67,6 +67,8 @@ class BlockStored(KVCacheEvent):
|
||||
KV cache consumers to reconstruct block hashes.
|
||||
"""
|
||||
|
||||
group_idx: int | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
@@ -77,6 +79,7 @@ class BlockStored(KVCacheEvent):
|
||||
self.lora_id,
|
||||
self.medium,
|
||||
tuple(self.extra_keys) if self.extra_keys else None,
|
||||
self.group_idx,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -84,9 +87,16 @@ class BlockStored(KVCacheEvent):
|
||||
class BlockRemoved(KVCacheEvent):
    """KV cache event listing block hashes removed from a storage medium."""

    # Hashes of the removed blocks.
    block_hashes: list[ExternalBlockHash]
    # Storage medium the blocks were removed from; may be None.
    medium: str | None
    # KV cache group the removed blocks belonged to; None when the producer
    # does not report a group.
    group_idx: int | None = None

    def __hash__(self) -> int:
        """Hash over the block hashes, the medium and the group index.

        group_idx is part of the key so that removals of the same block
        hashes from different KV cache groups hash differently.
        """
        # A single return: the earlier two-field hash that preceded this one
        # made the group-aware hash unreachable.
        return hash(
            (
                tuple(self.block_hashes),
                self.medium,
                self.group_idx,
            )
        )
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
|
||||
@@ -22,6 +22,7 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
get_block_hash,
|
||||
get_group_id,
|
||||
make_block_hash_with_group_id,
|
||||
maybe_convert_block_hash,
|
||||
)
|
||||
@@ -314,6 +315,7 @@ class BlockPool:
|
||||
if request.lora_request
|
||||
else None,
|
||||
extra_keys=extra_keys_list if extra_keys_list else None,
|
||||
group_idx=kv_cache_group_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -377,14 +379,11 @@ class BlockPool:
|
||||
block.reset_hash()
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
# FIXME (Chen): Not sure whether we should return `hash_value`
|
||||
# or `(hash_value, group_id)` here. But it's fine now because
|
||||
# we disable hybrid kv cache manager when kv cache event is
|
||||
# enabled, so there is only one group.
|
||||
self.kv_event_queue.append(
|
||||
BlockRemoved(
|
||||
block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))],
|
||||
medium=MEDIUM_GPU,
|
||||
group_idx=get_group_id(block_hash),
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user