[None][feat] Enhance support for complex models (#11254)

Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
This commit is contained in:
Yao Yao 2026-02-05 17:28:26 +08:00 committed by GitHub
parent 4c1d9d0c10
commit d9b936be94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 257 additions and 50 deletions

View File

@ -141,6 +141,9 @@ class _KVCache:
def set_page_index_buf(
self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
) -> None: ...
def set_base_page_index_buf(
    self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
) -> None: ...  # zero-copy setter for base page indices; None detaches the user buffer
@property
def manager(self) -> "KVCacheManager": ...
@property
@ -157,6 +160,9 @@ class _KVCache:
@beam_width.setter
def beam_width(self, beam_width: BeamIndex) -> None: ...
def get_page_indices(self, layer_group_id: int, beam_id: BeamIndex = ...) -> IndexSeq: ...
def get_base_page_indices(
    self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
) -> IndexSeq: ...  # unscaled (base) page indices; multiply by the manager's page index scale before kernel use
def get_aggregated_page_indices(
self,
layer_group_id: LayerGroupId,
@ -229,6 +235,7 @@ class KVCacheManager:
def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: ...
def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int: ...
def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int: ...
def get_page_index_scale(self, layer_id: LayerId, data_role: DataRole) -> int: ...
def create_kv_cache(
self,
lora_task_id: int | None = None,

View File

@ -168,7 +168,8 @@ class _KVCache:
"_history_length",
"_commit_state",
"_blocks",
"_page_indices",
"_base_page_indices",
"_page_indices", # Deprecated. To be removed in the future.
"_committed_tokens",
"_num_committed_blocks",
"_finish_event",
@ -193,6 +194,7 @@ class _KVCache:
_blocks: TypedIndexList[BlockOrdinal, SeqBlock]
# we maintain _page_indices to accelerate the get_page_indices() API. In principle it can be
# computed on the fly, but that would be slow due to python.
_base_page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]]
_page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]]
_committed_tokens: list[TokenIdExt]
# Sometimes we can't commit a block because all its tokens are already covered by another block in
@ -226,6 +228,10 @@ class _KVCache:
self._history_length = 0
self._commit_state = self.CommitState.ALLOWED
self._blocks = cast(TypedIndexList, [])
self._base_page_indices = make_typed(
lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles),
self.beam_width,
)
self._page_indices = make_typed(
lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles),
self.beam_width,
@ -242,7 +248,10 @@ class _KVCache:
def set_page_index_buf(
self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
) -> None:
"""Set the buffer for page indices, so we directly update indices in user buffer to
"""
Deprecated. Use set_base_page_index_buf() instead.
Set the buffer for page indices, so we directly update indices in user buffer to
avoid user-side copy. This is the zero-copy alternative of get_page_indices()"""
length = self.num_blocks
old_indices = self._page_indices[beam_idx][layer_group_id]
@ -256,6 +265,28 @@ class _KVCache:
new_indices = buf
self._page_indices[beam_idx][layer_group_id] = new_indices
def set_base_page_index_buf(
    self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
) -> None:
    """
    Install (or, with ``buf=None``, detach) a user-provided buffer for base
    page indices, so updates are written straight into user memory. This is
    the zero-copy counterpart of get_base_page_indices().

    Base page indices are not directly consumable by kernels: they must be
    scaled by the manager's page index scale first.
    """
    valid_len = self.num_blocks
    current = self._base_page_indices[beam_idx][layer_group_id]
    replacement: IndexSeq
    if buf is None:
        # Detach: copy the live prefix back into a privately owned array.
        replacement = array.array("i", current[:valid_len])
    else:
        assert buf.ndim == 1 and buf.format == "i" and len(buf) >= valid_len
        # Seed the user buffer with the live indices and poison the tail.
        buf[:valid_len] = current[:valid_len]
        buf[valid_len:] = array.array("i", [BAD_PAGE_INDEX]) * (len(buf) - valid_len)
        replacement = buf
    self._base_page_indices[beam_idx][layer_group_id] = replacement
@property
def manager(self) -> "KVCacheManager":
    # Owning KVCacheManager that created this KV cache.
    return self._manager
@ -314,6 +345,9 @@ class _KVCache:
def get_page_indices(
self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
) -> IndexSeq:
"""
Deprecated. Use get_base_page_indices() instead.
"""
indices = self._page_indices[beam_id][layer_group_id]
assert NDEBUG or all(
v == value_or(r, BAD_PAGE_INDEX)
@ -321,6 +355,16 @@ class _KVCache:
)
return indices
def get_base_page_indices(
    self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
) -> IndexSeq:
    """
    Return the cached base page indices for one beam and layer group.

    In debug builds (NDEBUG false) the cached values are cross-checked
    against the slow reference computation.
    """
    cached = self._base_page_indices[beam_id][layer_group_id]
    assert NDEBUG or all(
        got == value_or(want, BAD_PAGE_INDEX)
        for got, want in zip(
            cached, self._get_base_page_indices_ref(layer_group_id, beam_id)
        )
    )
    return cached
def get_aggregated_page_indices(
self,
layer_group_id: LayerGroupId,
@ -380,15 +424,16 @@ class _KVCache:
if new_num_blocks < old_num_blocks:
with self._record_event():
del self._blocks[new_num_blocks:]
for beam_indices in self._page_indices:
for indices in beam_indices:
assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:])
if type(indices) is array.array:
del indices[new_num_blocks:]
else:
indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * (
len(indices) - new_num_blocks
)
for page_indices in (self._base_page_indices, self._page_indices):
for beam_indices in page_indices:
for indices in beam_indices:
assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:])
if type(indices) is array.array:
del indices[new_num_blocks:]
else:
indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * (
len(indices) - new_num_blocks
)
elif new_num_blocks > old_num_blocks:
num_new_slots = filled_list(0, num_life_cycles)
stale_ranges = [
@ -408,13 +453,14 @@ class _KVCache:
except OutOfPagesError:
self._lock_held_blocks(backup_holders)
return False
for beam_indices in self._page_indices:
for indices in beam_indices:
if type(indices) is array.array:
assert len(indices) == old_num_blocks
indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks))
else:
assert len(indices) >= new_num_blocks
for page_indices in (self._base_page_indices, self._page_indices):
for beam_indices in page_indices:
for indices in beam_indices:
if type(indices) is array.array:
assert len(indices) == old_num_blocks
indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks))
else:
assert len(indices) >= new_num_blocks
stream_wait_events(
self.cuda_stream, (s.ready_event for s in chain.from_iterable(slots))
)
@ -543,6 +589,10 @@ class _KVCache:
assert self.status == self.Status.ACTIVE
assert self._check_sanity()
assert self._finish_event is None
for beam_idx, beam_indices in typed_enumerate(self._base_page_indices):
for lc, indices in typed_enumerate(beam_indices):
if type(indices) is memoryview:
self.set_base_page_index_buf(beam_idx, lc, None)
for beam_idx, beam_indices in typed_enumerate(self._page_indices):
for lc, indices in typed_enumerate(beam_indices):
if type(indices) is memoryview:
@ -999,12 +1049,13 @@ class _KVCache:
"failure by disallowing partial matching."
)
self._num_committed_blocks = BlockOrdinal(len(self._committed_tokens) // tokens_per_block)
for beam_indices in self._page_indices:
for indices in beam_indices:
if type(indices) is array.array:
indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices)))
else:
assert len(indices) >= self.num_blocks
for page_indices in (self._base_page_indices, self._page_indices):
for beam_indices in page_indices:
for indices in beam_indices:
if type(indices) is array.array:
indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices)))
else:
assert len(indices) >= self.num_blocks
def _clear_blocks(self) -> None:
# drop the last block first
@ -1020,6 +1071,14 @@ class _KVCache:
finally:
self._finish_event = None
def _update_base_page_index(
    self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex
) -> PageIndex:
    """Store one base page index and return the value it replaced."""
    row = self._base_page_indices[beam_idx][lc]
    previous = PageIndex(row[ordinal])
    row[ordinal] = page_index
    return previous
def _update_page_index(
self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex
) -> PageIndex:
@ -1042,6 +1101,13 @@ class _KVCache:
)
return self._storage.get_page_indices_ref(lc, pages)
def _get_base_page_indices_ref(
    self, lc: LifeCycleId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
) -> Iterator[int | None]:
    """Slow reference for base page indices; used only for debug checks."""
    assert beam_id < self.beam_width
    assert self.is_active
    # For base indices the aggregated view is already the reference answer.
    return self.get_aggregated_page_indices(lc, beam_id)
def _shortcut_set_capacity(self, capacity: int) -> bool:
"Shortcut for cases without side effects. Just for better performance."
tokens_per_block = self.tokens_per_block

View File

@ -152,6 +152,14 @@ class KVCacheManager:
slot_size = pool_group.slot_size[pool_idx]
return exact_div(slot_size, attr.size) * num_slots - exact_div(attr.offset, attr.size)
def get_page_index_scale(self, layer_id: LayerId, data_role: DataRole) -> int:
    """
    The multiplier that converts base page indices into the page indices
    expected by operators/kernels for this (layer, role) buffer.
    """
    attr = self._storage.get_buffer_attr(layer_id, data_role)
    return self._storage._slot_to_page_indices[attr.life_cycle_id][attr.pool_index]
def create_kv_cache(
self,
lora_task_id: int | None = None,

View File

@ -41,7 +41,7 @@ if TYPE_CHECKING:
from ._eviction_controller import NodeRef
from ._exceptions import LogicError
from ._life_cycle_registry import LifeCycleId
from ._storage._core import Slot
from ._storage._core import PoolIndex0, Slot
from ._utils import (
CachedCudaEvent,
assert_critical,
@ -390,6 +390,10 @@ class _SharedPageLock:
if not skip_wait:
self.page.ready_event.wait_in_stream(kv_cache.cuda_stream)
self._user = LockOwner(rawref.ref(kv_cache), beam_index, ordinal, life_cycle)
old_base_index = kv_cache._update_base_page_index(
beam_index, ordinal, life_cycle, PageIndex(self.page.slot_id)
)
assert old_base_index == BAD_PAGE_INDEX
new_index = self._get_page_index()
old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index)
assert old_index == BAD_PAGE_INDEX
@ -402,19 +406,29 @@ class _SharedPageLock:
assert self._uniq_lock is not None
page = self.page
self._uniq_lock.finish_events.append(unwrap_rawref(self._user.kv_cache).finish_event)
beam_index = self._user.beam_index
ordinal = self._user.ordinal
life_cycle = self._user.life_cycle
kv_cache = unwrap_rawref(self._user.kv_cache)
new_index = BAD_PAGE_INDEX
old_index = unwrap_rawref(self._user.kv_cache)._update_page_index(
self._user.beam_index, self._user.ordinal, self._user.life_cycle, new_index
old_base_index = kv_cache._update_base_page_index(
beam_index, ordinal, life_cycle, new_index
)
assert NDEBUG or old_base_index == self._get_base_page_index()
old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index)
assert NDEBUG or old_index == self._get_page_index()
self._uniq_lock = None
return page
def _get_page_index(self) -> PageIndex:
    """
    Kernel-facing page index of this page: the slot id scaled by the number
    of buffers per slot for the owner's life cycle (first pool).
    """
    # Fix: drop the stale pre-change assignment that computed
    # `storage._slot_to_page_indices[self._user.life_cycle]` without the
    # pool subscript — it was a dead store shadowed by the real one.
    user = self._user
    storage = unwrap_rawref(user.kv_cache).manager._storage
    num_buffers_per_slot = storage._slot_to_page_indices[user.life_cycle][PoolIndex0]
    return PageIndex(self.page.slot_id * num_buffers_per_slot)
def _get_base_page_index(self) -> PageIndex:
    # Base (unscaled) page index is simply the backing slot id.
    return PageIndex(self.page.slot_id)
BlockPage = _SharedPageLock | _PageHolder | None

View File

@ -119,14 +119,13 @@ class StorageConfig:
offset += cb.single_buffer_size
return ret
def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]]:
    """
    For each life cycle, the per-pool count of buffers per slot, used to
    scale base page indices into kernel-facing page indices.
    """
    # Fix: `[[]] * n` creates n references to ONE shared list, so any life
    # cycle that is never assigned below would alias every other unassigned
    # entry. Build independent inner lists instead.
    ret: list[list[int]] = [[] for _ in range(self.num_life_cycles)]
    for pg in self.slot_desc_list:
        for slot in pg.variants:
            ret[slot.life_cycle_id] = [cb.num_buffers for cb in slot.coalesced_buffers]
    return cast(TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]], ret)
def layer_to_life_cycle_ids(self) -> dict[LayerId, LifeCycleId]:
map = dict[LayerId, LifeCycleId]()
@ -160,11 +159,6 @@ def create_storage_config(config: KVCacheManagerConfig) -> StorageConfig:
# @TODO: add test for this case.
slot_groups: list[SlotDescVariant] = []
for life_cycle_id, size_to_buffers in buffer_groups.items():
assert len(set(len(buffer_ids) for buffer_ids in size_to_buffers.values())) == 1, (
"Not yet supported. While we can support this easily, we need to know whether the kernels "
"need to share page indices or not. We haven't seen such models, yet. So we leave this as a "
"future work."
)
slots = [
CoalescedBuffer(size, tuple(buffer_ids)) for size, buffer_ids in size_to_buffers.items()
]

View File

@ -58,6 +58,10 @@ PoolGroupIndex = NewType("PoolGroupIndex", int)
PoolIndex = NewType("PoolIndex", int)
SlotId = NewType("SlotId", int)
# A temporary work-around while migrating to new page index API.
# To be removed later.
PoolIndex0 = PoolIndex(0)
class SlotPoolBase(abc.ABC):
_slot_size: int

View File

@ -46,6 +46,7 @@ from ._storage._core import (
PoolGroupBase,
PoolGroupIndex,
PoolIndex,
PoolIndex0,
Slot,
SlotId,
)
@ -159,7 +160,7 @@ class StorageManager:
)
_life_cycles: LifeCycleRegistry
_layer_to_life_cycle_ids: dict[LayerId, LifeCycleId]
_slot_to_page_indices: TypedIndexList[LifeCycleId, int]
_slot_to_page_indices: TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]]
_buffer_attr: dict[BufferId, BufferAttr]
_life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex]
_levels: TypedIndexList[CacheLevel, CacheLevelManager]
@ -480,7 +481,8 @@ class StorageManager:
self, lc_id: LifeCycleId, pages: Iterator[Page | None]
) -> Iterator[int | None]:
"Reference implementation. Not fast enough for production."
scale = self._slot_to_page_indices[lc_id]
scale = self._slot_to_page_indices[lc_id][PoolIndex0]
assert all(scale == s for s in self._slot_to_page_indices[lc_id])
return (map_optional(page, lambda p: scale * int(p.slot_id)) for page in pages)
def get_buffer_attr(self, layer_id: LayerId, data_role: DataRole) -> BufferAttr:
@ -492,8 +494,8 @@ class StorageManager:
return self._levels[level].storage.slot_address(pg_idx, pool_idx, slot_id)
def get_page_indices_for_slot(self, life_cycle: LifeCycleId, slot_id: SlotId) -> PageIndex:
    """Scale a slot id into the page index expected by operators/kernels."""
    # Fix: remove the stale pre-change pair of lines that returned
    # `PageIndex(scale * slot_id)` with `scale` still being the whole
    # per-pool list — it returned early with the wrong (unscaled) value and
    # made the correct computation unreachable.
    scale = self._slot_to_page_indices[life_cycle][PoolIndex0]
    return PageIndex(scale * int(slot_id))
def get_statistics(
self, level: CacheLevel = GPU_LEVEL

View File

@ -138,7 +138,12 @@ class FakeEngine:
pool = manager.get_mem_pool_base_address(layer_id, role)
stride = manager.get_page_stride(layer_id, role)
lc_id = manager._storage._layer_to_life_cycle_ids[layer_id]
pages = kv_cache.get_page_indices(lc_id, beam)
base_pages = kv_cache.get_base_page_indices(lc_id, beam)
page_scale = manager.get_page_index_scale(layer_id, role)
pages = [
BAD_PAGE_INDEX if base_page is BAD_PAGE_INDEX else base_page * page_scale
for base_page in base_pages
]
capacity = kv_cache.capacity
history_len = len(history)
assert len(history) == history_len
@ -181,9 +186,14 @@ class FakeEngine:
pool = manager.get_mem_pool_base_address(layer_id, role)
stride = manager.get_page_stride(layer_id, role)
lc_id = manager._storage._layer_to_life_cycle_ids[layer_id]
pages = kv_cache.get_page_indices(lc_id, beam)[
base_pages = kv_cache.get_base_page_indices(lc_id, beam)[
: div_up(history_len + len(input), tokens_per_block)
]
page_scale = manager.get_page_index_scale(layer_id, role)
pages = [
BAD_PAGE_INDEX if base_page is BAD_PAGE_INDEX else base_page * page_scale
for base_page in base_pages
]
capacity = kv_cache.capacity
input_range = (history_len, history_len + len(input))
assert input_range[1] <= capacity

View File

@ -36,6 +36,7 @@ if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None:
BufferSlice,
CacheLevel,
CudaStream,
DataRole,
DiskCacheTierConfig,
GpuCacheTierConfig,
HostCacheTierConfig,
@ -78,6 +79,7 @@ else:
BufferSlice,
CacheLevel,
CudaStream,
DataRole,
DiskCacheTierConfig,
GpuCacheTierConfig,
HostCacheTierConfig,
@ -340,12 +342,12 @@ class TestNoBatching(TestKVCacheManagerV2):
if use_external_page_index_buf:
max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block)
num_layer_groups = len(self.manager.layer_grouping)
page_indices = [
base_page_indices = [
array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups)
]
for id in range(num_layer_groups):
req0.kv_cache.set_page_index_buf(
DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(page_indices[id])
req0.kv_cache.set_base_page_index_buf(
DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(base_page_indices[id])
)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
@ -367,12 +369,12 @@ class TestNoBatching(TestKVCacheManagerV2):
if use_external_page_index_buf:
max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block)
num_layer_groups = len(self.manager.layer_grouping)
page_indices = [
base_page_indices = [
array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups)
]
for id in range(num_layer_groups):
req0.kv_cache.set_page_index_buf(
DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(page_indices[id])
req0.kv_cache.set_base_page_index_buf(
DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(base_page_indices[id])
)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
@ -954,5 +956,105 @@ class TestDisaggregatedServing(unittest.TestCase):
node.kv_cache.close()
class TestComplexModels(unittest.TestCase):
    """Construction tests for models mixing buffer sizes and window settings."""

    def setUp(self) -> None:
        init_cuda_once()
        # Collect first, then disable GC so object destruction order stays
        # deterministic for the duration of the test body.
        gc.collect()
        gc.disable()

    def tearDown(self) -> None:
        gc.enable()

    @staticmethod
    def _make_layer(layer_id: int, size: int, window: int | None) -> AttentionLayerConfig:
        # One attention layer with a single `buf0` buffer of `size` bytes.
        # `window=None` means full (non-sliding-window) attention.
        return AttentionLayerConfig(
            layer_id=LayerId(layer_id),
            buffers=[BufferConfig(role=DataRole("buf0"), size=size)],
            sliding_window_size=window,
            num_sink_tokens=None,
        )

    def test_complex_model_0(self) -> None:
        # (buffer size, sliding window) per layer. The mix of sizes and
        # windows exercises the complex slot-grouping logic.
        layer_specs: list[tuple[int, int | None]] = [
            (131072, 128),
            (131072, 128),
            (98304, None),
            (163840, 64),
            (163840, 64),
            (65536, None),
            (131072, 64),
            (131072, 64),
            (131072, 128),
            (32768, None),
            (262144, 128),
            (262144, 128),
        ]
        layers = [
            self._make_layer(i, size, window)
            for i, (size, window) in enumerate(layer_specs)
        ]
        config = KVCacheManagerConfig(
            tokens_per_block=128,
            vocab_size=1024,
            cache_tiers=[
                GpuCacheTierConfig(quota=1024 * 1024 * 1024),
                HostCacheTierConfig(quota=8000 << 20),
            ],
            max_util_for_resume=0.95,
            layers=layers,
        )
        # Constructing (and destroying) the manager is the test: it must not
        # raise for this heterogeneous layer layout.
        manager = KVCacheManager(config)
        del manager
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()