commit d9b936be94 (parent 4c1d9d0c10)

[None][feat] Enhance support for complex models (#11254)

Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
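This commit splits the page-index bookkeeping of the KV cache manager into two index spaces: base page indices, which are raw slot ids independent of buffer layout, and the kernel-facing page indices, which are the base indices scaled per buffer. Each buffer now gets its own scale (one entry per pool) instead of one multiplier per life cycle, which is what lets layers with differently sized buffers — the "complex models" of the title — share one manager. A minimal consumer-side sketch, assuming `kv_cache`, `manager`, `lc_id`, `beam`, `layer_id` and `role` exist as in the updated FakeEngine test further down (sketch only, mirroring the pattern that test uses):

    base_pages = kv_cache.get_base_page_indices(lc_id, beam)   # unscaled slot ids
    page_scale = manager.get_page_index_scale(layer_id, role)  # per-buffer multiplier
    kernel_pages = [
        BAD_PAGE_INDEX if p == BAD_PAGE_INDEX else p * page_scale for p in base_pages
    ]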
@@ -141,6 +141,9 @@ class _KVCache:
     def set_page_index_buf(
         self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
     ) -> None: ...
+    def set_base_page_index_buf(
+        self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
+    ) -> None: ...
     @property
     def manager(self) -> "KVCacheManager": ...
     @property
@@ -157,6 +160,9 @@ class _KVCache:
     @beam_width.setter
     def beam_width(self, beam_width: BeamIndex) -> None: ...
     def get_page_indices(self, layer_group_id: int, beam_id: BeamIndex = ...) -> IndexSeq: ...
+    def get_base_page_indices(
+        self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
+    ) -> IndexSeq: ...
     def get_aggregated_page_indices(
         self,
         layer_group_id: LayerGroupId,
@@ -229,6 +235,7 @@ class KVCacheManager:
     def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: ...
     def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int: ...
     def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int: ...
+    def get_page_index_scale(self, layer_id: LayerId, data_role: DataRole) -> int: ...
     def create_kv_cache(
         self,
         lora_task_id: int | None = None,
@@ -168,7 +168,8 @@ class _KVCache:
         "_history_length",
         "_commit_state",
         "_blocks",
-        "_page_indices",
+        "_base_page_indices",
+        "_page_indices",  # Deprecated. To be removed in the future.
         "_committed_tokens",
         "_num_committed_blocks",
         "_finish_event",
@@ -193,6 +194,7 @@ class _KVCache:
     _blocks: TypedIndexList[BlockOrdinal, SeqBlock]
     # we maintain _page_indices to accelerate the get_page_indices() API. In principle it can be
     # computed on the fly, but that would be slow due to python.
+    _base_page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]]
     _page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]]
     _committed_tokens: list[TokenIdExt]
     # Sometimes we can't commit a block because all its tokens are already covered by another block in
@@ -226,6 +228,10 @@ class _KVCache:
         self._history_length = 0
         self._commit_state = self.CommitState.ALLOWED
         self._blocks = cast(TypedIndexList, [])
+        self._base_page_indices = make_typed(
+            lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles),
+            self.beam_width,
+        )
         self._page_indices = make_typed(
             lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles),
             self.beam_width,
@@ -242,7 +248,10 @@ class _KVCache:
     def set_page_index_buf(
         self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
     ) -> None:
-        """Set the buffer for page indices, so we directly update indices in user buffer to
+        """
+        Deprecated. Use set_base_page_index_buf() instead.
+
+        Set the buffer for page indices, so we directly update indices in user buffer to
         avoid user-side copy. This is the zero-copy alternative of get_page_indices()"""
         length = self.num_blocks
         old_indices = self._page_indices[beam_idx][layer_group_id]
@@ -256,6 +265,28 @@ class _KVCache:
             new_indices = buf
         self._page_indices[beam_idx][layer_group_id] = new_indices

+    def set_base_page_index_buf(
+        self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
+    ) -> None:
+        """
+        Set the buffer for base page indices, so we directly update indices in user buffer to
+        avoid user-side copy. This is the zero-copy alternative of get_base_page_indices().
+
+        Note that base page indices are not meant for direct use in the kernels. They need to
+        be scaled by kv_cache_manager.get_page_index_scale().
+        """
+        length = self.num_blocks
+        old_indices = self._base_page_indices[beam_idx][layer_group_id]
+        new_indices: IndexSeq
+        if buf is None:
+            new_indices = array.array("i", old_indices[:length])
+        else:
+            assert buf.ndim == 1 and buf.format == "i" and len(buf) >= length
+            buf[:length] = old_indices[:length]
+            buf[length:] = array.array("i", [BAD_PAGE_INDEX]) * (len(buf) - length)
+            new_indices = buf
+        self._base_page_indices[beam_idx][layer_group_id] = new_indices
+
     @property
     def manager(self) -> "KVCacheManager":
         return self._manager
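A minimal sketch of this zero-copy path, following the updated TestNoBatching case further down; `max_num_blocks` is an assumed caller-side upper bound on the number of blocks:

    import array

    buf = array.array("i", [-1]) * max_num_blocks  # filled with BAD_PAGE_INDEX, as in the tests
    kv_cache.set_base_page_index_buf(DEFAULT_BEAM_INDEX, LayerGroupId(0), memoryview(buf))
    # From here on the manager writes base page indices straight into buf; no per-step copy.
    kv_cache.set_base_page_index_buf(DEFAULT_BEAM_INDEX, LayerGroupId(0), None)
    # Passing None detaches buf; the current indices are copied back to an internal array.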
@@ -314,6 +345,9 @@ class _KVCache:
     def get_page_indices(
         self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
     ) -> IndexSeq:
+        """
+        Deprecated. Use get_base_page_indices() instead.
+        """
         indices = self._page_indices[beam_id][layer_group_id]
         assert NDEBUG or all(
             v == value_or(r, BAD_PAGE_INDEX)
@@ -321,6 +355,16 @@ class _KVCache:
         )
         return indices

+    def get_base_page_indices(
+        self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
+    ) -> IndexSeq:
+        indices = self._base_page_indices[beam_id][layer_group_id]
+        assert NDEBUG or all(
+            v == value_or(r, BAD_PAGE_INDEX)
+            for v, r in zip(indices, self._get_base_page_indices_ref(layer_group_id, beam_id))
+        )
+        return indices
+
     def get_aggregated_page_indices(
         self,
         layer_group_id: LayerGroupId,
@@ -380,15 +424,16 @@ class _KVCache:
         if new_num_blocks < old_num_blocks:
             with self._record_event():
                 del self._blocks[new_num_blocks:]
-                for beam_indices in self._page_indices:
-                    for indices in beam_indices:
-                        assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:])
-                        if type(indices) is array.array:
-                            del indices[new_num_blocks:]
-                        else:
-                            indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * (
-                                len(indices) - new_num_blocks
-                            )
+                for page_indices in (self._base_page_indices, self._page_indices):
+                    for beam_indices in page_indices:
+                        for indices in beam_indices:
+                            assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:])
+                            if type(indices) is array.array:
+                                del indices[new_num_blocks:]
+                            else:
+                                indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * (
+                                    len(indices) - new_num_blocks
+                                )
         elif new_num_blocks > old_num_blocks:
             num_new_slots = filled_list(0, num_life_cycles)
             stale_ranges = [
@@ -408,13 +453,14 @@ class _KVCache:
         except OutOfPagesError:
             self._lock_held_blocks(backup_holders)
             return False
-        for beam_indices in self._page_indices:
-            for indices in beam_indices:
-                if type(indices) is array.array:
-                    assert len(indices) == old_num_blocks
-                    indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks))
-                else:
-                    assert len(indices) >= new_num_blocks
+        for page_indices in (self._base_page_indices, self._page_indices):
+            for beam_indices in page_indices:
+                for indices in beam_indices:
+                    if type(indices) is array.array:
+                        assert len(indices) == old_num_blocks
+                        indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks))
+                    else:
+                        assert len(indices) >= new_num_blocks
         stream_wait_events(
             self.cuda_stream, (s.ready_event for s in chain.from_iterable(slots))
         )
@@ -543,6 +589,10 @@ class _KVCache:
         assert self.status == self.Status.ACTIVE
         assert self._check_sanity()
         assert self._finish_event is None
+        for beam_idx, beam_indices in typed_enumerate(self._base_page_indices):
+            for lc, indices in typed_enumerate(beam_indices):
+                if type(indices) is memoryview:
+                    self.set_base_page_index_buf(beam_idx, lc, None)
         for beam_idx, beam_indices in typed_enumerate(self._page_indices):
             for lc, indices in typed_enumerate(beam_indices):
                 if type(indices) is memoryview:
@@ -999,12 +1049,13 @@ class _KVCache:
                 "failure by disallowing partial matching."
             )
         self._num_committed_blocks = BlockOrdinal(len(self._committed_tokens) // tokens_per_block)
-        for beam_indices in self._page_indices:
-            for indices in beam_indices:
-                if type(indices) is array.array:
-                    indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices)))
-                else:
-                    assert len(indices) >= self.num_blocks
+        for page_indices in (self._base_page_indices, self._page_indices):
+            for beam_indices in page_indices:
+                for indices in beam_indices:
+                    if type(indices) is array.array:
+                        indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices)))
+                    else:
+                        assert len(indices) >= self.num_blocks

     def _clear_blocks(self) -> None:
         # drop the last block first
@@ -1020,6 +1071,14 @@ class _KVCache:
         finally:
             self._finish_event = None

+    def _update_base_page_index(
+        self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex
+    ) -> PageIndex:
+        indices = self._base_page_indices[beam_idx][lc]
+        old = PageIndex(indices[ordinal])
+        indices[ordinal] = page_index
+        return old
+
     def _update_page_index(
         self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex
     ) -> PageIndex:
@@ -1042,6 +1101,13 @@ class _KVCache:
         )
         return self._storage.get_page_indices_ref(lc, pages)

+    def _get_base_page_indices_ref(
+        self, lc: LifeCycleId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
+    ) -> Iterator[int | None]:
+        assert beam_id < self.beam_width
+        assert self.is_active
+        return self.get_aggregated_page_indices(lc, beam_id)
+
     def _shortcut_set_capacity(self, capacity: int) -> bool:
         "Shortcut for cases without side effects. Just for better performance."
         tokens_per_block = self.tokens_per_block
@@ -152,6 +152,14 @@ class KVCacheManager:
         slot_size = pool_group.slot_size[pool_idx]
         return exact_div(slot_size, attr.size) * num_slots - exact_div(attr.offset, attr.size)

+    def get_page_index_scale(self, layer_id: LayerId, data_role: DataRole) -> int:
+        """
+        The multiplier to convert from base page indices to page indices expected by operators/kernels.
+        """
+        storage = self._storage
+        attr = storage.get_buffer_attr(layer_id, data_role)
+        return storage._slot_to_page_indices[attr.life_cycle_id][attr.pool_index]
+
     def create_kv_cache(
         self,
         lora_task_id: int | None = None,
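The scale is the number of coalesced buffers per page in the pool backing this buffer (see the StorageConfig change below), so the conversion is a plain multiply. A hypothetical worked example, values invented:

    # With two buffers (say K and V) coalesced into each page, the scale is 2,
    # so base page index (slot id) 7 addresses kernel-facing page index 14.
    page_scale = 2
    base_page_index = 7                      # from get_base_page_indices()
    page_index = base_page_index * page_scale
    assert page_index == 14                  # index the kernels use into the pool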
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
     from ._eviction_controller import NodeRef
     from ._exceptions import LogicError
     from ._life_cycle_registry import LifeCycleId
-    from ._storage._core import Slot
+    from ._storage._core import PoolIndex0, Slot
     from ._utils import (
         CachedCudaEvent,
         assert_critical,
@@ -390,6 +390,10 @@ class _SharedPageLock:
         if not skip_wait:
             self.page.ready_event.wait_in_stream(kv_cache.cuda_stream)
         self._user = LockOwner(rawref.ref(kv_cache), beam_index, ordinal, life_cycle)
+        old_base_index = kv_cache._update_base_page_index(
+            beam_index, ordinal, life_cycle, PageIndex(self.page.slot_id)
+        )
+        assert old_base_index == BAD_PAGE_INDEX
         new_index = self._get_page_index()
         old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index)
         assert old_index == BAD_PAGE_INDEX
@@ -402,19 +406,29 @@ class _SharedPageLock:
         assert self._uniq_lock is not None
         page = self.page
         self._uniq_lock.finish_events.append(unwrap_rawref(self._user.kv_cache).finish_event)
+        beam_index = self._user.beam_index
+        ordinal = self._user.ordinal
+        life_cycle = self._user.life_cycle
+        kv_cache = unwrap_rawref(self._user.kv_cache)
         new_index = BAD_PAGE_INDEX
-        old_index = unwrap_rawref(self._user.kv_cache)._update_page_index(
-            self._user.beam_index, self._user.ordinal, self._user.life_cycle, new_index
+        old_base_index = kv_cache._update_base_page_index(
+            beam_index, ordinal, life_cycle, new_index
         )
+        assert NDEBUG or old_base_index == self._get_base_page_index()
+        old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index)
         assert NDEBUG or old_index == self._get_page_index()
         self._uniq_lock = None
         return page

     def _get_page_index(self) -> PageIndex:
         storage = unwrap_rawref(self._user.kv_cache).manager._storage
-        num_buffers_per_slot = storage._slot_to_page_indices[self._user.life_cycle]
+        user = self._user
+        num_buffers_per_slot = storage._slot_to_page_indices[user.life_cycle][PoolIndex0]
         return PageIndex(self.page.slot_id * num_buffers_per_slot)

+    def _get_base_page_index(self) -> PageIndex:
+        return PageIndex(self.page.slot_id)
+

 BlockPage = _SharedPageLock | _PageHolder | None

@@ -119,14 +119,13 @@ class StorageConfig:
             offset += cb.single_buffer_size
         return ret

-    def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, int]:
-        ret = filled_list(0, self.num_life_cycles)
+    def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]]:
+        ret = [[]] * self.num_life_cycles
         for pg in self.slot_desc_list:
             for slot in pg.variants:
                 life_cycle = slot.life_cycle_id
-                page = slot.coalesced_buffers[0]
-                ret[life_cycle] = page.num_buffers
-        return ret
+                ret[life_cycle] = [cb.num_buffers for cb in slot.coalesced_buffers]
+        return cast(TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]], ret)

     def layer_to_life_cycle_ids(self) -> dict[LayerId, LifeCycleId]:
         map = dict[LayerId, LifeCycleId]()
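The return shape changes from one scale per life cycle to one entry per pool of each life cycle; indexing with PoolIndex0 recovers the legacy value. A hypothetical illustration, values invented:

    # Life cycle 0: one pool whose slot coalesces K and V (2 buffers per page).
    # Life cycle 1: two pools with one buffer per page each.
    old_ret = [2, 1]            # TypedIndexList[LifeCycleId, int]
    new_ret = [[2], [1, 1]]     # TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]]
    assert new_ret[0][0] == old_ret[0]  # [PoolIndex0] reproduces the old per-life-cycle scale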
@@ -160,11 +159,6 @@ def create_storage_config(config: KVCacheManagerConfig) -> StorageConfig:
     # @TODO: add test for this case.
     slot_groups: list[SlotDescVariant] = []
     for life_cycle_id, size_to_buffers in buffer_groups.items():
-        assert len(set(len(buffer_ids) for buffer_ids in size_to_buffers.values())) == 1, (
-            "Not yet supported. While we can support this easily, we need to know whether the kernels "
-            "need to share page indices or not. We haven't seen such models, yet. So we leave this as a "
-            "future work."
-        )
         slots = [
             CoalescedBuffer(size, tuple(buffer_ids)) for size, buffer_ids in size_to_buffers.items()
         ]
@@ -58,6 +58,10 @@ PoolGroupIndex = NewType("PoolGroupIndex", int)
 PoolIndex = NewType("PoolIndex", int)
 SlotId = NewType("SlotId", int)

+# A temporary work-around while migrating to the new page index API.
+# To be removed later.
+PoolIndex0 = PoolIndex(0)
+

 class SlotPoolBase(abc.ABC):
     _slot_size: int
@@ -46,6 +46,7 @@ from ._storage._core import (
     PoolGroupBase,
     PoolGroupIndex,
     PoolIndex,
+    PoolIndex0,
     Slot,
     SlotId,
 )
@@ -159,7 +160,7 @@ class StorageManager:
     )
     _life_cycles: LifeCycleRegistry
     _layer_to_life_cycle_ids: dict[LayerId, LifeCycleId]
-    _slot_to_page_indices: TypedIndexList[LifeCycleId, int]
+    _slot_to_page_indices: TypedIndexList[LifeCycleId, TypedIndexList[PoolIndex, int]]
     _buffer_attr: dict[BufferId, BufferAttr]
     _life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex]
     _levels: TypedIndexList[CacheLevel, CacheLevelManager]
@@ -480,7 +481,8 @@ class StorageManager:
         self, lc_id: LifeCycleId, pages: Iterator[Page | None]
     ) -> Iterator[int | None]:
         "Reference implementation. Not fast enough for production."
-        scale = self._slot_to_page_indices[lc_id]
+        scale = self._slot_to_page_indices[lc_id][PoolIndex0]
+        assert all(scale == s for s in self._slot_to_page_indices[lc_id])
         return (map_optional(page, lambda p: scale * int(p.slot_id)) for page in pages)

     def get_buffer_attr(self, layer_id: LayerId, data_role: DataRole) -> BufferAttr:
@@ -492,8 +494,8 @@ class StorageManager:
         return self._levels[level].storage.slot_address(pg_idx, pool_idx, slot_id)

     def get_page_indices_for_slot(self, life_cycle: LifeCycleId, slot_id: SlotId) -> PageIndex:
-        scale = self._slot_to_page_indices[life_cycle]
-        return PageIndex(scale * slot_id)
+        scale = self._slot_to_page_indices[life_cycle][PoolIndex0]
+        return PageIndex(scale * int(slot_id))

     def get_statistics(
         self, level: CacheLevel = GPU_LEVEL
@@ -138,7 +138,12 @@ class FakeEngine:
         pool = manager.get_mem_pool_base_address(layer_id, role)
         stride = manager.get_page_stride(layer_id, role)
         lc_id = manager._storage._layer_to_life_cycle_ids[layer_id]
-        pages = kv_cache.get_page_indices(lc_id, beam)
+        base_pages = kv_cache.get_base_page_indices(lc_id, beam)
+        page_scale = manager.get_page_index_scale(layer_id, role)
+        pages = [
+            BAD_PAGE_INDEX if base_page is BAD_PAGE_INDEX else base_page * page_scale
+            for base_page in base_pages
+        ]
         capacity = kv_cache.capacity
         history_len = len(history)
         assert len(history) == history_len
@@ -181,9 +186,14 @@ class FakeEngine:
         pool = manager.get_mem_pool_base_address(layer_id, role)
         stride = manager.get_page_stride(layer_id, role)
         lc_id = manager._storage._layer_to_life_cycle_ids[layer_id]
-        pages = kv_cache.get_page_indices(lc_id, beam)[
+        base_pages = kv_cache.get_base_page_indices(lc_id, beam)[
             : div_up(history_len + len(input), tokens_per_block)
         ]
+        page_scale = manager.get_page_index_scale(layer_id, role)
+        pages = [
+            BAD_PAGE_INDEX if base_page is BAD_PAGE_INDEX else base_page * page_scale
+            for base_page in base_pages
+        ]
         capacity = kv_cache.capacity
         input_range = (history_len, history_len + len(input))
         assert input_range[1] <= capacity
@@ -36,6 +36,7 @@ if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None:
         BufferSlice,
         CacheLevel,
         CudaStream,
+        DataRole,
         DiskCacheTierConfig,
         GpuCacheTierConfig,
         HostCacheTierConfig,
@@ -78,6 +79,7 @@ else:
         BufferSlice,
         CacheLevel,
         CudaStream,
+        DataRole,
         DiskCacheTierConfig,
         GpuCacheTierConfig,
         HostCacheTierConfig,
@@ -340,12 +342,12 @@ class TestNoBatching(TestKVCacheManagerV2):
         if use_external_page_index_buf:
             max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block)
             num_layer_groups = len(self.manager.layer_grouping)
-            page_indices = [
+            base_page_indices = [
                 array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups)
             ]
             for id in range(num_layer_groups):
-                req0.kv_cache.set_page_index_buf(
-                    DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(page_indices[id])
+                req0.kv_cache.set_base_page_index_buf(
+                    DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(base_page_indices[id])
                 )
         with TemporaryCudaStream([]) as s:
             stream = cast(CudaStream, s.handle)
@@ -367,12 +369,12 @@ class TestNoBatching(TestKVCacheManagerV2):
         if use_external_page_index_buf:
             max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block)
             num_layer_groups = len(self.manager.layer_grouping)
-            page_indices = [
+            base_page_indices = [
                 array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups)
            ]
             for id in range(num_layer_groups):
-                req0.kv_cache.set_page_index_buf(
-                    DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(page_indices[id])
+                req0.kv_cache.set_base_page_index_buf(
+                    DEFAULT_BEAM_INDEX, LayerGroupId(id), memoryview(base_page_indices[id])
                 )
         with TemporaryCudaStream([]) as s:
             stream = cast(CudaStream, s.handle)
@@ -954,5 +956,105 @@ class TestDisaggregatedServing(unittest.TestCase):
             node.kv_cache.close()


+class TestComplexModels(unittest.TestCase):
+    def setUp(self) -> None:
+        init_cuda_once()
+        gc.collect()
+        gc.disable()
+
+    def tearDown(self) -> None:
+        gc.enable()
+
+    def test_complex_model_0(self) -> None:
+        role = DataRole("buf0")
+        layers = [
+            AttentionLayerConfig(
+                layer_id=LayerId(0),
+                buffers=[BufferConfig(role=role, size=131072)],
+                sliding_window_size=128,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(1),
+                buffers=[BufferConfig(role=role, size=131072)],
+                sliding_window_size=128,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(2),
+                buffers=[BufferConfig(role=role, size=98304)],
+                sliding_window_size=None,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(3),
+                buffers=[BufferConfig(role=role, size=163840)],
+                sliding_window_size=64,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(4),
+                buffers=[BufferConfig(role=role, size=163840)],
+                sliding_window_size=64,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(5),
+                buffers=[BufferConfig(role=role, size=65536)],
+                sliding_window_size=None,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(6),
+                buffers=[BufferConfig(role=role, size=131072)],
+                sliding_window_size=64,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(7),
+                buffers=[BufferConfig(role=role, size=131072)],
+                sliding_window_size=64,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(8),
+                buffers=[BufferConfig(role=role, size=131072)],
+                sliding_window_size=128,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(9),
+                buffers=[BufferConfig(role=role, size=32768)],
+                sliding_window_size=None,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(10),
+                buffers=[BufferConfig(role=role, size=262144)],
+                sliding_window_size=128,
+                num_sink_tokens=None,
+            ),
+            AttentionLayerConfig(
+                layer_id=LayerId(11),
+                buffers=[BufferConfig(role=role, size=262144)],
+                sliding_window_size=128,
+                num_sink_tokens=None,
+            ),
+        ]
+
+        config = KVCacheManagerConfig(
+            tokens_per_block=128,
+            vocab_size=1024,
+            cache_tiers=[
+                GpuCacheTierConfig(quota=1024 * 1024 * 1024),
+                HostCacheTierConfig(quota=8000 << 20),
+            ],
+            max_util_for_resume=0.95,
+            layers=layers,
+        )
+        manager = KVCacheManager(config)
+        del manager
+
+
 if __name__ == "__main__":
     unittest.main()