diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi index 549f0617a1..4df8477755 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi @@ -141,6 +141,9 @@ class _KVCache: def set_page_index_buf( self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None ) -> None: ... + def set_base_page_index_buf( + self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None + ) -> None: ... @property def manager(self) -> "KVCacheManager": ... @property @@ -157,6 +160,9 @@ class _KVCache: @beam_width.setter def beam_width(self, beam_width: BeamIndex) -> None: ... def get_page_indices(self, layer_group_id: int, beam_id: BeamIndex = ...) -> IndexSeq: ... + def get_base_page_indices( + self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX + ) -> IndexSeq: ... def get_aggregated_page_indices( self, layer_group_id: LayerGroupId, @@ -229,6 +235,7 @@ class KVCacheManager: def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: ... def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int: ... def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int: ... + def get_page_index_scale(self, layer_id: LayerId, data_role: DataRole) -> int: ... def create_kv_cache( self, lora_task_id: int | None = None, diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py index bfafda94b2..7edc56b23d 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py @@ -168,7 +168,8 @@ class _KVCache: "_history_length", "_commit_state", "_blocks", - "_page_indices", + "_base_page_indices", + "_page_indices", # Deprecated. To be removed in the future. 
"_committed_tokens", "_num_committed_blocks", "_finish_event", @@ -193,6 +194,7 @@ class _KVCache: _blocks: TypedIndexList[BlockOrdinal, SeqBlock] # we maintain _page_indices to accelerate the get_page_indices() API. In principle it can be # computed on the fly, but that would be slow due to python. + _base_page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]] _page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]] _committed_tokens: list[TokenIdExt] # Sometimes we can't commit a block because all its tokens are already covered by another block in @@ -226,6 +228,10 @@ class _KVCache: self._history_length = 0 self._commit_state = self.CommitState.ALLOWED self._blocks = cast(TypedIndexList, []) + self._base_page_indices = make_typed( + lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles), + self.beam_width, + ) self._page_indices = make_typed( lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles), self.beam_width, @@ -242,7 +248,10 @@ class _KVCache: def set_page_index_buf( self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None ) -> None: - """Set the buffer for page indices, so we directly update indices in user buffer to + """ + Deprecated. Use set_base_page_index_buf() instead. + + Set the buffer for page indices, so we directly update indices in user buffer to avoid user-side copy. This is the zero-copy alternative of get_page_indices()""" length = self.num_blocks old_indices = self._page_indices[beam_idx][layer_group_id] @@ -256,6 +265,28 @@ class _KVCache: new_indices = buf self._page_indices[beam_idx][layer_group_id] = new_indices + def set_base_page_index_buf( + self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None + ) -> None: + """ + Set the buffer for base page indices, so we directly update indices in user buffer to + avoid user-side copy. 
This is the zero-copy alternative of get_base_page_indices(). + + Note that base page indices are not meant for direct use in the kernels. They need to + be scaled by kv_cache_manager.page_index_scale(). + """ + length = self.num_blocks + old_indices = self._base_page_indices[beam_idx][layer_group_id] + new_indices: IndexSeq + if buf is None: + new_indices = array.array("i", old_indices[:length]) + else: + assert buf.ndim == 1 and buf.format == "i" and len(buf) >= length + buf[:length] = old_indices[:length] + buf[length:] = array.array("i", [BAD_PAGE_INDEX]) * (len(buf) - length) + new_indices = buf + self._base_page_indices[beam_idx][layer_group_id] = new_indices + @property def manager(self) -> "KVCacheManager": return self._manager @@ -314,6 +345,9 @@ class _KVCache: def get_page_indices( self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX ) -> IndexSeq: + """ + Deprecated. Use get_base_page_indices() instead. + """ indices = self._page_indices[beam_id][layer_group_id] assert NDEBUG or all( v == value_or(r, BAD_PAGE_INDEX) @@ -321,6 +355,16 @@ class _KVCache: ) return indices + def get_base_page_indices( + self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX + ) -> IndexSeq: + indices = self._base_page_indices[beam_id][layer_group_id] + assert NDEBUG or all( + v == value_or(r, BAD_PAGE_INDEX) + for v, r in zip(indices, self._get_base_page_indices_ref(layer_group_id, beam_id)) + ) + return indices + def get_aggregated_page_indices( self, layer_group_id: LayerGroupId, @@ -380,15 +424,16 @@ class _KVCache: if new_num_blocks < old_num_blocks: with self._record_event(): del self._blocks[new_num_blocks:] - for beam_indices in self._page_indices: - for indices in beam_indices: - assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:]) - if type(indices) is array.array: - del indices[new_num_blocks:] - else: - indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * ( - len(indices) - new_num_blocks - ) + 
for page_indices in (self._base_page_indices, self._page_indices): + for beam_indices in page_indices: + for indices in beam_indices: + assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:]) + if type(indices) is array.array: + del indices[new_num_blocks:] + else: + indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * ( + len(indices) - new_num_blocks + ) elif new_num_blocks > old_num_blocks: num_new_slots = filled_list(0, num_life_cycles) stale_ranges = [ @@ -408,13 +453,14 @@ class _KVCache: except OutOfPagesError: self._lock_held_blocks(backup_holders) return False - for beam_indices in self._page_indices: - for indices in beam_indices: - if type(indices) is array.array: - assert len(indices) == old_num_blocks - indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks)) - else: - assert len(indices) >= new_num_blocks + for page_indices in (self._base_page_indices, self._page_indices): + for beam_indices in page_indices: + for indices in beam_indices: + if type(indices) is array.array: + assert len(indices) == old_num_blocks + indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks)) + else: + assert len(indices) >= new_num_blocks stream_wait_events( self.cuda_stream, (s.ready_event for s in chain.from_iterable(slots)) ) @@ -543,6 +589,10 @@ class _KVCache: assert self.status == self.Status.ACTIVE assert self._check_sanity() assert self._finish_event is None + for beam_idx, beam_indices in typed_enumerate(self._base_page_indices): + for lc, indices in typed_enumerate(beam_indices): + if type(indices) is memoryview: + self.set_base_page_index_buf(beam_idx, lc, None) for beam_idx, beam_indices in typed_enumerate(self._page_indices): for lc, indices in typed_enumerate(beam_indices): if type(indices) is memoryview: @@ -999,12 +1049,13 @@ class _KVCache: "failure by disallowing partial matching." 
) self._num_committed_blocks = BlockOrdinal(len(self._committed_tokens) // tokens_per_block) - for beam_indices in self._page_indices: - for indices in beam_indices: - if type(indices) is array.array: - indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices))) - else: - assert len(indices) >= self.num_blocks + for page_indices in (self._base_page_indices, self._page_indices): + for beam_indices in page_indices: + for indices in beam_indices: + if type(indices) is array.array: + indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices))) + else: + assert len(indices) >= self.num_blocks def _clear_blocks(self) -> None: # drop the last block first @@ -1020,6 +1071,14 @@ class _KVCache: finally: self._finish_event = None + def _update_base_page_index( + self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex + ) -> PageIndex: + indices = self._base_page_indices[beam_idx][lc] + old = PageIndex(indices[ordinal]) + indices[ordinal] = page_index + return old + def _update_page_index( self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex ) -> PageIndex: @@ -1042,6 +1101,13 @@ class _KVCache: ) return self._storage.get_page_indices_ref(lc, pages) + def _get_base_page_indices_ref( + self, lc: LifeCycleId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX + ) -> Iterator[int | None]: + assert beam_id < self.beam_width + assert self.is_active + return self.get_aggregated_page_indices(lc, beam_id) + def _shortcut_set_capacity(self, capacity: int) -> bool: "Shortcut for cases without side effects. Just for better performance." 
tokens_per_block = self.tokens_per_block diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py index 925aca4be8..3184ecd531 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py @@ -152,6 +152,14 @@ class KVCacheManager: slot_size = pool_group.slot_size[pool_idx] return exact_div(slot_size, attr.size) * num_slots - exact_div(attr.offset, attr.size) + def get_page_index_scale(self, layer_id: LayerId, data_role: DataRole) -> int: + """ + The multiplier to convert from base page indices to page indices expected by operators/kernels. + """ + storage = self._storage + attr = storage.get_buffer_attr(layer_id, data_role) + return storage._slot_to_page_indices[attr.life_cycle_id][attr.pool_index] + def create_kv_cache( self, lora_task_id: int | None = None, diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py index 9727feb81f..5bff868603 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: from ._eviction_controller import NodeRef from ._exceptions import LogicError from ._life_cycle_registry import LifeCycleId -from ._storage._core import Slot +from ._storage._core import PoolIndex0, Slot from ._utils import ( CachedCudaEvent, assert_critical, @@ -390,6 +390,10 @@ class _SharedPageLock: if not skip_wait: self.page.ready_event.wait_in_stream(kv_cache.cuda_stream) self._user = LockOwner(rawref.ref(kv_cache), beam_index, ordinal, life_cycle) + old_base_index = kv_cache._update_base_page_index( + beam_index, ordinal, life_cycle, PageIndex(self.page.slot_id) + ) + assert old_base_index == BAD_PAGE_INDEX new_index = self._get_page_index() old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index) 
assert old_index == BAD_PAGE_INDEX @@ -402,19 +406,29 @@ class _SharedPageLock: assert self._uniq_lock is not None page = self.page self._uniq_lock.finish_events.append(unwrap_rawref(self._user.kv_cache).finish_event) + beam_index = self._user.beam_index + ordinal = self._user.ordinal + life_cycle = self._user.life_cycle + kv_cache = unwrap_rawref(self._user.kv_cache) new_index = BAD_PAGE_INDEX - old_index = unwrap_rawref(self._user.kv_cache)._update_page_index( - self._user.beam_index, self._user.ordinal, self._user.life_cycle, new_index + old_base_index = kv_cache._update_base_page_index( + beam_index, ordinal, life_cycle, new_index ) + assert NDEBUG or old_base_index == self._get_base_page_index() + old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index) assert NDEBUG or old_index == self._get_page_index() self._uniq_lock = None return page def _get_page_index(self) -> PageIndex: storage = unwrap_rawref(self._user.kv_cache).manager._storage - num_buffers_per_slot = storage._slot_to_page_indices[self._user.life_cycle] + user = self._user + num_buffers_per_slot = storage._slot_to_page_indices[user.life_cycle][PoolIndex0] return PageIndex(self.page.slot_id * num_buffers_per_slot) + def _get_base_page_index(self) -> PageIndex: + return PageIndex(self.page.slot_id) + BlockPage = _SharedPageLock | _PageHolder | None diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py index 69a23cebd4..460b89d6cc 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py @@ -119,14 +119,13 @@ class StorageConfig: offset += cb.single_buffer_size return ret - def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, int]: - ret = filled_list(0, self.num_life_cycles) + def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]]: + ret = [[]] * 
self.num_life_cycles for pg in self.slot_desc_list: for slot in pg.variants: life_cycle = slot.life_cycle_id - page = slot.coalesced_buffers[0] - ret[life_cycle] = page.num_buffers - return ret + ret[life_cycle] = [cb.num_buffers for cb in slot.coalesced_buffers] + return cast(TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]], ret) def layer_to_life_cycle_ids(self) -> dict[LayerId, LifeCycleId]: map = dict[LayerId, LifeCycleId]() @@ -160,11 +159,6 @@ def create_storage_config(config: KVCacheManagerConfig) -> StorageConfig: # @TODO: add test for this case. slot_groups: list[SlotDescVariant] = [] for life_cycle_id, size_to_buffers in buffer_groups.items(): - assert len(set(len(buffer_ids) for buffer_ids in size_to_buffers.values())) == 1, ( - "Not yet supported. While we can support this easily, we need to know whether the kernels " - "need to share page indices or not. We haven't seen such models, yet. So we leave this as a " - "future work." - ) slots = [ CoalescedBuffer(size, tuple(buffer_ids)) for size, buffer_ids in size_to_buffers.items() ] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py index 4d66c18258..8191611aaf 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py @@ -58,6 +58,10 @@ PoolGroupIndex = NewType("PoolGroupIndex", int) PoolIndex = NewType("PoolIndex", int) SlotId = NewType("SlotId", int) +# A temporary work-around while migrating to new page index API. +# To be removed later. 
+PoolIndex0 = PoolIndex(0) + class SlotPoolBase(abc.ABC): _slot_size: int diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py index 87a78d58b0..d75c4e7d2d 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py @@ -46,6 +46,7 @@ from ._storage._core import ( PoolGroupBase, PoolGroupIndex, PoolIndex, + PoolIndex0, Slot, SlotId, ) @@ -159,7 +160,7 @@ class StorageManager: ) _life_cycles: LifeCycleRegistry _layer_to_life_cycle_ids: dict[LayerId, LifeCycleId] - _slot_to_page_indices: TypedIndexList[LifeCycleId, int] + _slot_to_page_indices: TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]] _buffer_attr: dict[BufferId, BufferAttr] _life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex] _levels: TypedIndexList[CacheLevel, CacheLevelManager] @@ -480,7 +481,8 @@ class StorageManager: self, lc_id: LifeCycleId, pages: Iterator[Page | None] ) -> Iterator[int | None]: "Reference implementation. Not fast enough for production." 
- scale = self._slot_to_page_indices[lc_id] + scale = self._slot_to_page_indices[lc_id][PoolIndex0] + assert all(scale == s for s in self._slot_to_page_indices[lc_id]) + return (map_optional(page, lambda p: scale * int(p.slot_id)) for page in pages) def get_buffer_attr(self, layer_id: LayerId, data_role: DataRole) -> BufferAttr: @@ -492,8 +494,8 @@ class StorageManager: return self._levels[level].storage.slot_address(pg_idx, pool_idx, slot_id) def get_page_indices_for_slot(self, life_cycle: LifeCycleId, slot_id: SlotId) -> PageIndex: - scale = self._slot_to_page_indices[life_cycle] - return PageIndex(scale * slot_id) + scale = self._slot_to_page_indices[life_cycle][PoolIndex0] + return PageIndex(scale * int(slot_id)) def get_statistics( self, level: CacheLevel = GPU_LEVEL diff --git a/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py b/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py index c5c6d233ce..5bce76e5f3 100644 --- a/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py +++ b/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py @@ -138,7 +138,12 @@ class FakeEngine: pool = manager.get_mem_pool_base_address(layer_id, role) stride = manager.get_page_stride(layer_id, role) lc_id = manager._storage._layer_to_life_cycle_ids[layer_id] - pages = kv_cache.get_page_indices(lc_id, beam) + base_pages = kv_cache.get_base_page_indices(lc_id, beam) + page_scale = manager.get_page_index_scale(layer_id, role) + pages = [ + BAD_PAGE_INDEX if base_page == BAD_PAGE_INDEX else base_page * page_scale + for base_page in base_pages + ] capacity = kv_cache.capacity history_len = len(history) assert len(history) == history_len @@ -181,9 +186,14 @@ class FakeEngine: pool = manager.get_mem_pool_base_address(layer_id, role) stride = manager.get_page_stride(layer_id, role) lc_id = manager._storage._layer_to_life_cycle_ids[layer_id] - pages = kv_cache.get_page_indices(lc_id, beam)[ + base_pages = kv_cache.get_base_page_indices(lc_id, beam)[ : div_up(history_len + 
len(input), tokens_per_block) ] + page_scale = manager.get_page_index_scale(layer_id, role) + pages = [ + BAD_PAGE_INDEX if base_page == BAD_PAGE_INDEX else base_page * page_scale + for base_page in base_pages + ] capacity = kv_cache.capacity input_range = (history_len, history_len + len(input)) assert input_range[1] <= capacity diff --git a/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py b/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py index 73d6b0b75c..acedb8704d 100755 --- a/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py +++ b/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py @@ -36,6 +36,7 @@ if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None: BufferSlice, CacheLevel, CudaStream, + DataRole, DiskCacheTierConfig, GpuCacheTierConfig, HostCacheTierConfig, @@ -78,6 +79,7 @@ else: BufferSlice, CacheLevel, CudaStream, + DataRole, DiskCacheTierConfig, GpuCacheTierConfig, HostCacheTierConfig, @@ -340,12 +342,12 @@ class TestNoBatching(TestKVCacheManagerV2): if use_external_page_index_buf: max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block) num_layer_groups = len(self.manager.layer_grouping) - page_indices = [ + base_page_indices = [ array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups) ] for id in range(num_layer_groups): - req0.kv_cache.set_page_index_buf( - DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(page_indices[id]) + req0.kv_cache.set_base_page_index_buf( + DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(base_page_indices[id]) ) with TemporaryCudaStream([]) as s: stream = cast(CudaStream, s.handle) @@ -367,12 +369,12 @@ class TestNoBatching(TestKVCacheManagerV2): if use_external_page_index_buf: max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block) num_layer_groups = len(self.manager.layer_grouping) - page_indices = [ + base_page_indices = [ array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups) ] for id in 
range(num_layer_groups): - req0.kv_cache.set_page_index_buf( - DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(page_indices[id]) + req0.kv_cache.set_base_page_index_buf( + DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(base_page_indices[id]) ) with TemporaryCudaStream([]) as s: stream = cast(CudaStream, s.handle) @@ -954,5 +956,105 @@ class TestDisaggregatedServing(unittest.TestCase): node.kv_cache.close() +class TestComplexModels(unittest.TestCase): + def setUp(self) -> None: + init_cuda_once() + gc.collect() + gc.disable() + + def tearDown(self) -> None: + gc.enable() + + def test_complex_model_0(self) -> None: + role = DataRole("buf0") + layers = [ + AttentionLayerConfig( + layer_id=LayerId(0), + buffers=[BufferConfig(role=role, size=131072)], + sliding_window_size=128, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(1), + buffers=[BufferConfig(role=role, size=131072)], + sliding_window_size=128, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(2), + buffers=[BufferConfig(role=role, size=98304)], + sliding_window_size=None, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(3), + buffers=[BufferConfig(role=role, size=163840)], + sliding_window_size=64, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(4), + buffers=[BufferConfig(role=role, size=163840)], + sliding_window_size=64, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(5), + buffers=[BufferConfig(role=role, size=65536)], + sliding_window_size=None, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(6), + buffers=[BufferConfig(role=role, size=131072)], + sliding_window_size=64, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(7), + buffers=[BufferConfig(role=role, size=131072)], + sliding_window_size=64, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(8), + buffers=[BufferConfig(role=role, size=131072)], + 
sliding_window_size=128, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(9), + buffers=[BufferConfig(role=role, size=32768)], + sliding_window_size=None, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(10), + buffers=[BufferConfig(role=role, size=262144)], + sliding_window_size=128, + num_sink_tokens=None, + ), + AttentionLayerConfig( + layer_id=LayerId(11), + buffers=[BufferConfig(role=role, size=262144)], + sliding_window_size=128, + num_sink_tokens=None, + ), + ] + + config = KVCacheManagerConfig( + tokens_per_block=128, + vocab_size=1024, + cache_tiers=[ + GpuCacheTierConfig(quota=1024 * 1024 * 1024), + HostCacheTierConfig(quota=8000 << 20), + ], + max_util_for_resume=0.95, + layers=layers, + ) + manager = KVCacheManager(config) + del manager + + if __name__ == "__main__": unittest.main()