mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Pack KV caches into contiguous per-block allocations for DeepSeek V4
For DeepSeek V4, pack all layer data contiguously per block so that NIXL registers one region per block instead of one per layer. Full-attention MLA + SWA/compressor caches share one contiguous allocation per block. Each layer gets an as_strided view with storage_offset into the packed backing tensor. The packed backing tensor and block_stride are passed through the cross-layer KV cache registration API so NIXL registers one region with block_len=block_stride instead of many separate regions. Previously: 92 NIXL regions, 92 tiny P2P transfers per block (~16KB). Now: 1 region, 1 large RDMA transfer per block (~1.48MB). Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
@@ -0,0 +1,214 @@
|
|||||||
|
"""Tests for contiguous KV block packing in _get_kv_cache_config_deepseek_v4."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.v1.core.kv_cache_utils import _get_kv_cache_config_deepseek_v4
|
||||||
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
KVCacheGroupSpec,
|
||||||
|
MLAAttentionSpec,
|
||||||
|
UniformTypeKVCacheSpecs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mla_spec(page_size: int, block_size: int = 256) -> MLAAttentionSpec:
|
||||||
|
return MLAAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=512,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
page_size_padded=page_size,
|
||||||
|
cache_dtype_str="fp8_ds_mla",
|
||||||
|
model_version="deepseek_v4",
|
||||||
|
alignment=576,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_groups(n_c4, n_c128, n_swa):
|
||||||
|
PS_C4_MLA = 37440
|
||||||
|
PS_C4_IDX = 8640
|
||||||
|
PS_C128 = 1728
|
||||||
|
PS_SWA = 37440
|
||||||
|
|
||||||
|
mla_specs = {}
|
||||||
|
for i in range(n_c4):
|
||||||
|
mla_specs[f"c4_mla.{i}"] = _make_mla_spec(PS_C4_MLA)
|
||||||
|
mla_specs[f"c4_idx.{i}"] = _make_mla_spec(PS_C4_IDX)
|
||||||
|
for i in range(n_c128):
|
||||||
|
mla_specs[f"c128_mla.{i}"] = _make_mla_spec(PS_C128)
|
||||||
|
|
||||||
|
mla_group = KVCacheGroupSpec(
|
||||||
|
layer_names=list(mla_specs.keys()),
|
||||||
|
kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=mla_specs),
|
||||||
|
)
|
||||||
|
|
||||||
|
swa_specs = {}
|
||||||
|
for i in range(n_swa):
|
||||||
|
swa_specs[f"swa.{i}"] = _make_mla_spec(PS_SWA)
|
||||||
|
|
||||||
|
swa_group = KVCacheGroupSpec(
|
||||||
|
layer_names=list(swa_specs.keys()),
|
||||||
|
kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=swa_specs),
|
||||||
|
)
|
||||||
|
|
||||||
|
return [mla_group, swa_group]
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_vllm_config():
|
||||||
|
config = MagicMock()
|
||||||
|
config.scheduler_config.num_gpu_blocks_override = None
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _run(n_c4=3, n_c128=2, n_swa=5, mem=100 * 1024 * 1024):
|
||||||
|
groups = _make_groups(n_c4, n_c128, n_swa)
|
||||||
|
return _get_kv_cache_config_deepseek_v4(_mock_vllm_config(), groups, mem)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_tensors(tensors):
|
||||||
|
mla = [t for t in tensors if any("mla" in n or "idx" in n for n in t.shared_by)]
|
||||||
|
swa = [t for t in tensors if any("swa" in n for n in t.shared_by)]
|
||||||
|
return mla, swa
|
||||||
|
|
||||||
|
|
||||||
|
class TestMlaContiguousPacking:
|
||||||
|
def test_all_mla_share_one_size(self):
|
||||||
|
_, tensors = _run()
|
||||||
|
mla, _ = _split_tensors(tensors)
|
||||||
|
assert len(set(t.size for t in mla)) == 1
|
||||||
|
|
||||||
|
def test_all_mla_share_one_block_stride(self):
|
||||||
|
_, tensors = _run()
|
||||||
|
mla, _ = _split_tensors(tensors)
|
||||||
|
strides = set(t.block_stride for t in mla)
|
||||||
|
assert len(strides) == 1
|
||||||
|
assert strides.pop() > 0
|
||||||
|
|
||||||
|
def test_mla_offsets_are_unique(self):
|
||||||
|
_, tensors = _run(n_c4=5, n_c128=4)
|
||||||
|
mla, _ = _split_tensors(tensors)
|
||||||
|
offsets = [t.offset for t in mla]
|
||||||
|
assert len(offsets) == len(set(offsets))
|
||||||
|
|
||||||
|
def test_mla_offsets_fit_within_block_stride(self):
|
||||||
|
_, tensors = _run(n_c4=5, n_c128=4)
|
||||||
|
mla, _ = _split_tensors(tensors)
|
||||||
|
stride = mla[0].block_stride
|
||||||
|
for t in mla:
|
||||||
|
assert 0 <= t.offset < stride, f"offset {t.offset} >= block_stride {stride}"
|
||||||
|
|
||||||
|
def test_mla_all_layers_accounted_for(self):
|
||||||
|
n_c4, n_c128 = 5, 4
|
||||||
|
_, tensors = _run(n_c4=n_c4, n_c128=n_c128)
|
||||||
|
mla, _ = _split_tensors(tensors)
|
||||||
|
all_names = set()
|
||||||
|
for t in mla:
|
||||||
|
all_names.update(t.shared_by)
|
||||||
|
expected_count = n_c4 * 2 + n_c128 # c4_mla + c4_idx + c128_mla
|
||||||
|
assert len(all_names) == expected_count
|
||||||
|
|
||||||
|
def test_size_equals_stride_times_num_blocks(self):
|
||||||
|
num_blocks, tensors = _run()
|
||||||
|
mla, _ = _split_tensors(tensors)
|
||||||
|
for t in mla:
|
||||||
|
assert t.size == t.block_stride * num_blocks
|
||||||
|
|
||||||
|
|
||||||
|
class TestSwaContiguousPacking:
|
||||||
|
def test_swa_packed_separately_from_mla(self):
|
||||||
|
_, tensors = _run()
|
||||||
|
mla, swa = _split_tensors(tensors)
|
||||||
|
assert len(swa) > 0
|
||||||
|
assert len(mla) > 0
|
||||||
|
assert set(t.size for t in swa) != set(t.size for t in mla)
|
||||||
|
|
||||||
|
def test_swa_share_one_size(self):
|
||||||
|
_, tensors = _run()
|
||||||
|
_, swa = _split_tensors(tensors)
|
||||||
|
assert len(set(t.size for t in swa)) == 1
|
||||||
|
|
||||||
|
def test_swa_share_one_block_stride(self):
|
||||||
|
_, tensors = _run()
|
||||||
|
_, swa = _split_tensors(tensors)
|
||||||
|
strides = set(t.block_stride for t in swa)
|
||||||
|
assert len(strides) == 1
|
||||||
|
assert strides.pop() > 0
|
||||||
|
|
||||||
|
def test_swa_offsets_are_unique(self):
|
||||||
|
_, tensors = _run(n_swa=10)
|
||||||
|
_, swa = _split_tensors(tensors)
|
||||||
|
offsets = [t.offset for t in swa]
|
||||||
|
assert len(offsets) == len(set(offsets))
|
||||||
|
|
||||||
|
def test_swa_all_layers_accounted_for(self):
|
||||||
|
n_swa = 7
|
||||||
|
_, tensors = _run(n_swa=n_swa)
|
||||||
|
_, swa = _split_tensors(tensors)
|
||||||
|
all_names = set()
|
||||||
|
for t in swa:
|
||||||
|
all_names.update(t.shared_by)
|
||||||
|
assert len(all_names) == n_swa
|
||||||
|
|
||||||
|
def test_swa_size_equals_stride_times_num_blocks(self):
|
||||||
|
num_blocks, tensors = _run()
|
||||||
|
_, swa = _split_tensors(tensors)
|
||||||
|
for t in swa:
|
||||||
|
assert t.size == t.block_stride * num_blocks
|
||||||
|
|
||||||
|
|
||||||
|
class TestStridedViewCorrectness:
|
||||||
|
def test_views_are_independent(self):
|
||||||
|
page_sizes = [1728, 8640, 37440]
|
||||||
|
layer_tuple_bytes = sum(page_sizes)
|
||||||
|
num_tuples = 3
|
||||||
|
block_stride = layer_tuple_bytes * num_tuples
|
||||||
|
num_blocks = 4
|
||||||
|
|
||||||
|
backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)
|
||||||
|
|
||||||
|
views = []
|
||||||
|
for t in range(num_tuples):
|
||||||
|
for ps_idx, ps in enumerate(page_sizes):
|
||||||
|
offset = t * layer_tuple_bytes + sum(page_sizes[:ps_idx])
|
||||||
|
view = torch.as_strided(
|
||||||
|
backing,
|
||||||
|
size=(num_blocks, ps),
|
||||||
|
stride=(block_stride, 1),
|
||||||
|
storage_offset=offset,
|
||||||
|
)
|
||||||
|
views.append((f"tuple{t}_ps{ps}", view))
|
||||||
|
|
||||||
|
for i, (name, view) in enumerate(views):
|
||||||
|
view.fill_(i + 1)
|
||||||
|
for j, (_, other) in enumerate(views):
|
||||||
|
if j <= i:
|
||||||
|
continue
|
||||||
|
assert other.sum() == 0, f"Writing to {name} corrupted view {j}"
|
||||||
|
|
||||||
|
for i, (name, view) in enumerate(views):
|
||||||
|
assert view.sum() == (i + 1) * view.numel()
|
||||||
|
|
||||||
|
def test_block_isolation(self):
|
||||||
|
ps = 37440
|
||||||
|
n_layers = 5
|
||||||
|
block_stride = ps * n_layers
|
||||||
|
num_blocks = 3
|
||||||
|
|
||||||
|
backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)
|
||||||
|
views = [
|
||||||
|
torch.as_strided(backing, (num_blocks, ps), (block_stride, 1), layer * ps)
|
||||||
|
for layer in range(n_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
views[2][1].fill_(42)
|
||||||
|
assert views[2][0].sum() == 0
|
||||||
|
assert views[2][2].sum() == 0
|
||||||
|
for i in range(n_layers):
|
||||||
|
if i != 2:
|
||||||
|
assert views[i].sum() == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -259,7 +259,10 @@ class KVConnectorBase_V1(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_backend: type["AttentionBackend"] | None,
|
||||||
|
block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with a single KV cache tensor used by all layers.
|
Initialize with a single KV cache tensor used by all layers.
|
||||||
|
|||||||
@@ -248,7 +248,10 @@ class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
|
|||||||
self.connector_worker.register_kv_caches(kv_caches)
|
self.connector_worker.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_backend: type | None,
|
||||||
|
block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -236,11 +236,16 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_backend: type[AttentionBackend] | None,
|
||||||
|
block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
# Register on all connectors
|
# Register on all connectors
|
||||||
for c in self._connectors:
|
for c in self._connectors:
|
||||||
c.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
c.register_cross_layers_kv_cache(
|
||||||
|
kv_cache, attn_backend, block_stride=block_stride
|
||||||
|
)
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
for c in self._connectors:
|
for c in self._connectors:
|
||||||
|
|||||||
@@ -210,10 +210,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
|
|||||||
self.connector_worker.register_kv_caches(kv_caches)
|
self.connector_worker.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_backend: type[AttentionBackend] | None,
|
||||||
|
block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
self.connector_worker.register_cross_layers_kv_caches(kv_cache)
|
self.connector_worker.register_cross_layers_kv_caches(
|
||||||
|
kv_cache, block_stride=block_stride
|
||||||
|
)
|
||||||
|
|
||||||
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
|
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
|
|||||||
@@ -774,15 +774,18 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
fut.add_done_callback(request_ready)
|
fut.add_done_callback(request_ready)
|
||||||
|
|
||||||
def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
|
def register_cross_layers_kv_caches(
|
||||||
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
block_stride: int | None = None,
|
||||||
|
) -> None:
|
||||||
"""Register a cross-layers KV cache tensor with NIXL.
|
"""Register a cross-layers KV cache tensor with NIXL.
|
||||||
|
|
||||||
`use_uniform_kv_cache()` guarantees a single KV cache group whose
|
When block_stride is provided (packed heterogeneous layout),
|
||||||
layers all share the same `AttentionSpec`, so any layer name from
|
it overrides the per-layer page_size calculation.
|
||||||
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
|
|
||||||
"""
|
"""
|
||||||
|
self._packed_block_stride = block_stride
|
||||||
first_layer = next(iter(self._layer_specs))
|
first_layer = next(iter(self._layer_specs))
|
||||||
# Forwarding a real layer name rather than a synthetic key
|
|
||||||
self.register_kv_caches({first_layer: kv_cache})
|
self.register_kv_caches({first_layer: kv_cache})
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
@@ -859,19 +862,23 @@ class NixlConnectorWorker:
|
|||||||
)
|
)
|
||||||
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
|
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
|
||||||
# the page_size assuming constant `self._logical_num_blocks`.
|
# the page_size assuming constant `self._logical_num_blocks`.
|
||||||
physical_page_size = (
|
packed_stride = getattr(self, "_packed_block_stride", None)
|
||||||
layer_spec.page_size_bytes
|
if packed_stride is not None:
|
||||||
if isinstance(layer_spec, MambaSpec)
|
physical_page_size = packed_stride
|
||||||
else layer_spec.page_size_bytes
|
else:
|
||||||
// self._physical_blocks_per_logical_kv_block
|
physical_page_size = (
|
||||||
)
|
layer_spec.page_size_bytes
|
||||||
# For when registering multiple tensors eg K/V in separate regions.
|
if isinstance(layer_spec, MambaSpec)
|
||||||
physical_page_size = physical_page_size // len(cache_list)
|
else layer_spec.page_size_bytes
|
||||||
if self.transfer_topo._cross_layers_blocks:
|
// self._physical_blocks_per_logical_kv_block
|
||||||
# When cross-layers blocks are used, multiply by number of layers
|
|
||||||
physical_page_size = physical_page_size * len(
|
|
||||||
self.kv_cache_config.kv_cache_tensors
|
|
||||||
)
|
)
|
||||||
|
# For when registering multiple tensors eg K/V in separate regions.
|
||||||
|
physical_page_size = physical_page_size // len(cache_list)
|
||||||
|
if self.transfer_topo._cross_layers_blocks:
|
||||||
|
# When cross-layers blocks are used, multiply by number of layers
|
||||||
|
physical_page_size = physical_page_size * len(
|
||||||
|
self.kv_cache_config.kv_cache_tensors
|
||||||
|
)
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
self._logical_num_blocks
|
self._logical_num_blocks
|
||||||
if isinstance(layer_spec, MambaSpec)
|
if isinstance(layer_spec, MambaSpec)
|
||||||
@@ -906,7 +913,7 @@ class NixlConnectorWorker:
|
|||||||
else:
|
else:
|
||||||
self.block_len_per_layer.append(physical_page_size)
|
self.block_len_per_layer.append(physical_page_size)
|
||||||
|
|
||||||
if cache.shape[0] != num_blocks:
|
if packed_stride is None and cache.shape[0] != num_blocks:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"All kv cache tensors must have the same number of "
|
"All kv cache tensors must have the same number of "
|
||||||
f"blocks; layer={layer_name}, "
|
f"blocks; layer={layer_name}, "
|
||||||
|
|||||||
@@ -183,8 +183,12 @@ class OffloadingConnectorWorker:
|
|||||||
self._register_handlers(canonical_kv_caches)
|
self._register_handlers(canonical_kv_caches)
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_backend: type[AttentionBackend] | None,
|
||||||
|
block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
|
assert attn_backend is not None
|
||||||
# verify that num_blocks is at physical position 0 in the cross-layers
|
# verify that num_blocks is at physical position 0 in the cross-layers
|
||||||
# tensor layout.
|
# tensor layout.
|
||||||
test_shape = attn_backend.get_kv_cache_shape(
|
test_shape = attn_backend.get_kv_cache_shape(
|
||||||
|
|||||||
@@ -76,7 +76,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
|
|||||||
self.connector_worker.register_kv_caches(kv_caches)
|
self.connector_worker.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_backend: type[AttentionBackend] | None,
|
||||||
|
block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||||
|
|||||||
@@ -1240,17 +1240,43 @@ def _get_kv_cache_config_deepseek_v4(
|
|||||||
num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples)
|
num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples)
|
||||||
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
|
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
|
||||||
|
|
||||||
|
full_block_bytes = layer_tuple_page_bytes * num_layer_tuples
|
||||||
|
total_size = full_block_bytes * num_blocks
|
||||||
|
|
||||||
kv_cache_tensors: list[KVCacheTensor] = []
|
kv_cache_tensors: list[KVCacheTensor] = []
|
||||||
|
|
||||||
for tuple_idx in range(num_layer_tuples):
|
for tuple_idx in range(num_layer_tuples):
|
||||||
|
ps_offset = 0
|
||||||
for ps in page_sizes:
|
for ps in page_sizes:
|
||||||
shared_by: list[str] = []
|
mla_shared: list[str] = []
|
||||||
for b in bucketed:
|
swa_shared: list[str] = []
|
||||||
|
for g_idx, b in enumerate(bucketed):
|
||||||
bucket = b.get(ps)
|
bucket = b.get(ps)
|
||||||
if bucket is not None and tuple_idx < len(bucket):
|
if bucket is not None and tuple_idx < len(bucket):
|
||||||
shared_by.append(bucket[tuple_idx])
|
if g_idx == 0:
|
||||||
kv_cache_tensors.append(
|
mla_shared.append(bucket[tuple_idx])
|
||||||
KVCacheTensor(size=ps * num_blocks, shared_by=shared_by)
|
else:
|
||||||
)
|
swa_shared.append(bucket[tuple_idx])
|
||||||
|
offset = tuple_idx * layer_tuple_page_bytes + ps_offset
|
||||||
|
if mla_shared:
|
||||||
|
kv_cache_tensors.append(
|
||||||
|
KVCacheTensor(
|
||||||
|
size=total_size,
|
||||||
|
shared_by=mla_shared,
|
||||||
|
offset=offset,
|
||||||
|
block_stride=full_block_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if swa_shared:
|
||||||
|
kv_cache_tensors.append(
|
||||||
|
KVCacheTensor(
|
||||||
|
size=total_size,
|
||||||
|
shared_by=swa_shared,
|
||||||
|
offset=offset,
|
||||||
|
block_stride=full_block_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ps_offset += ps
|
||||||
|
|
||||||
return num_blocks, kv_cache_tensors
|
return num_blocks, kv_cache_tensors
|
||||||
|
|
||||||
|
|||||||
@@ -834,6 +834,8 @@ class KVCacheTensor:
|
|||||||
|
|
||||||
size: int # size of the KV cache tensor in bytes
|
size: int # size of the KV cache tensor in bytes
|
||||||
shared_by: list[str] # layer names that share the same KV cache tensor
|
shared_by: list[str] # layer names that share the same KV cache tensor
|
||||||
|
offset: int = 0 # byte offset of this layer within a contiguous block
|
||||||
|
block_stride: int = 0 # total bytes per block in a packed layout (0 = not packed)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -151,8 +151,16 @@ def _allocate_kv_cache(
|
|||||||
kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device
|
kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device
|
||||||
):
|
):
|
||||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||||
|
packed_backing: torch.Tensor | None = None
|
||||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
if kv_cache_tensor.block_stride > 0:
|
||||||
|
if packed_backing is None:
|
||||||
|
packed_backing = torch.zeros(
|
||||||
|
kv_cache_tensor.size, dtype=torch.int8, device=device
|
||||||
|
)
|
||||||
|
tensor = packed_backing
|
||||||
|
else:
|
||||||
|
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||||
for layer_name in kv_cache_tensor.shared_by:
|
for layer_name in kv_cache_tensor.shared_by:
|
||||||
kv_cache_raw_tensors[layer_name] = tensor
|
kv_cache_raw_tensors[layer_name] = tensor
|
||||||
|
|
||||||
@@ -163,7 +171,7 @@ def _allocate_kv_cache(
|
|||||||
assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), (
|
assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), (
|
||||||
"Some layers are not correctly initialized"
|
"Some layers are not correctly initialized"
|
||||||
)
|
)
|
||||||
return kv_cache_raw_tensors
|
return kv_cache_raw_tensors, packed_backing
|
||||||
|
|
||||||
|
|
||||||
def _reshape_kv_cache(
|
def _reshape_kv_cache(
|
||||||
@@ -172,10 +180,18 @@ def _reshape_kv_cache(
|
|||||||
cache_dtype: str,
|
cache_dtype: str,
|
||||||
kernel_block_sizes: list[int],
|
kernel_block_sizes: list[int],
|
||||||
shared_kv_cache_layers: dict[str, str],
|
shared_kv_cache_layers: dict[str, str],
|
||||||
|
kv_cache_config: "KVCacheConfig | None" = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
kv_caches: dict[str, Any] = {}
|
kv_caches: dict[str, Any] = {}
|
||||||
has_attn, has_mamba = False, False
|
has_attn, has_mamba = False, False
|
||||||
|
|
||||||
|
layer_packing: dict[str, tuple[int, int]] = {}
|
||||||
|
if kv_cache_config is not None:
|
||||||
|
for kv_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
|
if kv_tensor.block_stride > 0:
|
||||||
|
for ln in kv_tensor.shared_by:
|
||||||
|
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
|
||||||
|
|
||||||
for group in attn_groups:
|
for group in attn_groups:
|
||||||
if group.kv_cache_group_id >= len(kernel_block_sizes):
|
if group.kv_cache_group_id >= len(kernel_block_sizes):
|
||||||
continue
|
continue
|
||||||
@@ -194,8 +210,13 @@ def _reshape_kv_cache(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
kv_raw_tensor = kv_cache_raw_tensors[layer_name]
|
kv_raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||||
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
packing = layer_packing.get(layer_name)
|
||||||
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
if packing is not None:
|
||||||
|
_, blk_stride = packing
|
||||||
|
num_blocks = kv_raw_tensor.numel() // blk_stride
|
||||||
|
else:
|
||||||
|
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||||
|
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||||
|
|
||||||
if isinstance(kv_cache_spec, AttentionSpec):
|
if isinstance(kv_cache_spec, AttentionSpec):
|
||||||
has_attn = True
|
has_attn = True
|
||||||
@@ -229,14 +250,19 @@ def _reshape_kv_cache(
|
|||||||
|
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
kv_tensor = kv_raw_tensor.view(dtype)
|
kv_tensor = kv_raw_tensor.view(dtype)
|
||||||
if kv_cache_spec.page_size_padded is not None:
|
if packing is not None:
|
||||||
# Use strided view to handle page_size_bytes that
|
layer_offset, blk_stride = packing
|
||||||
# include padding. This follows the same pattern as
|
dtype_size = get_dtype_size(dtype)
|
||||||
# MambaSpec handling in gpu_model_runner.py.
|
page_stride = blk_stride // dtype_size
|
||||||
# NOTE: This assumes kv_cache_shape[0] == num_blocks
|
strides = list(torch.empty(kv_cache_shape).stride())
|
||||||
# (i.e. the first physical dimension is the block
|
strides[inv_order[0]] = page_stride
|
||||||
# index), which holds for all current backends
|
kv_cache = torch.as_strided(
|
||||||
# (MLA, FlashAttention, TritonAttention, etc.).
|
kv_tensor,
|
||||||
|
size=kv_cache_shape,
|
||||||
|
stride=tuple(strides),
|
||||||
|
storage_offset=layer_offset // dtype_size,
|
||||||
|
)
|
||||||
|
elif kv_cache_spec.page_size_padded is not None:
|
||||||
dtype_size = get_dtype_size(dtype)
|
dtype_size = get_dtype_size(dtype)
|
||||||
page_stride = kv_cache_spec.page_size_bytes // dtype_size
|
page_stride = kv_cache_spec.page_size_bytes // dtype_size
|
||||||
strides = list(torch.empty(kv_cache_shape).stride())
|
strides = list(torch.empty(kv_cache_shape).stride())
|
||||||
@@ -247,7 +273,6 @@ def _reshape_kv_cache(
|
|||||||
stride=tuple(strides),
|
stride=tuple(strides),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# No padding — safe to use a contiguous view.
|
|
||||||
kv_cache = kv_tensor.view(kv_cache_shape)
|
kv_cache = kv_tensor.view(kv_cache_shape)
|
||||||
kv_caches[layer_name] = kv_cache.permute(*inv_order)
|
kv_caches[layer_name] = kv_cache.permute(*inv_order)
|
||||||
|
|
||||||
@@ -349,9 +374,9 @@ def init_kv_cache(
|
|||||||
cache_dtype: str,
|
cache_dtype: str,
|
||||||
kernel_block_sizes: list[int],
|
kernel_block_sizes: list[int],
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
) -> dict[str, Any]:
|
) -> tuple[dict[str, Any], torch.Tensor | None, int | None]:
|
||||||
shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config)
|
shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config)
|
||||||
kv_cache_raw_tensors = _allocate_kv_cache(
|
kv_cache_raw_tensors, packed_backing = _allocate_kv_cache(
|
||||||
kv_cache_config, shared_kv_cache_layers, device
|
kv_cache_config, shared_kv_cache_layers, device
|
||||||
)
|
)
|
||||||
flattened_attn_groups = list(group for groups in attn_groups for group in groups)
|
flattened_attn_groups = list(group for groups in attn_groups for group in groups)
|
||||||
@@ -361,9 +386,21 @@ def init_kv_cache(
|
|||||||
kernel_block_sizes=kernel_block_sizes,
|
kernel_block_sizes=kernel_block_sizes,
|
||||||
cache_dtype=cache_dtype,
|
cache_dtype=cache_dtype,
|
||||||
shared_kv_cache_layers=shared_kv_cache_layers,
|
shared_kv_cache_layers=shared_kv_cache_layers,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
)
|
)
|
||||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||||
return kv_caches
|
|
||||||
|
packed_block_stride = None
|
||||||
|
if packed_backing is not None:
|
||||||
|
packed_block_stride = next(
|
||||||
|
(
|
||||||
|
t.block_stride
|
||||||
|
for t in kv_cache_config.kv_cache_tensors
|
||||||
|
if t.block_stride > 0
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return kv_caches, packed_backing, packed_block_stride
|
||||||
|
|
||||||
|
|
||||||
def build_slot_mappings_by_layer(
|
def build_slot_mappings_by_layer(
|
||||||
|
|||||||
@@ -46,14 +46,20 @@ class KVConnector:
|
|||||||
|
|
||||||
class ActiveKVConnector(KVConnector):
|
class ActiveKVConnector(KVConnector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_caches_dict: dict[str, torch.Tensor],
|
||||||
|
packed_backing: torch.Tensor | None = None,
|
||||||
|
packed_block_stride: int | None = None,
|
||||||
):
|
):
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.kv_connector = get_kv_transfer_group()
|
self.kv_connector = get_kv_transfer_group()
|
||||||
# Register kv caches with KV Connector if applicable.
|
if packed_backing is not None and packed_block_stride is not None:
|
||||||
# TODO: support cross_layers_kv_cache
|
self.kv_connector.register_cross_layers_kv_cache(
|
||||||
# (see https://github.com/vllm-project/vllm/pull/27743)
|
packed_backing, None, block_stride=packed_block_stride
|
||||||
self.kv_connector.register_kv_caches(kv_caches_dict)
|
)
|
||||||
|
else:
|
||||||
|
self.kv_connector.register_kv_caches(kv_caches_dict)
|
||||||
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
|
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||||
|
|
||||||
self._disabled = False
|
self._disabled = False
|
||||||
@@ -114,10 +120,17 @@ NO_OP_KV_CONNECTOR = KVConnector()
|
|||||||
|
|
||||||
|
|
||||||
def get_kv_connector(
|
def get_kv_connector(
|
||||||
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
|
vllm_config: VllmConfig,
|
||||||
|
kv_caches_dict: dict[str, torch.Tensor],
|
||||||
|
packed_backing: torch.Tensor | None = None,
|
||||||
|
packed_block_stride: int | None = None,
|
||||||
) -> KVConnector:
|
) -> KVConnector:
|
||||||
if not has_kv_transfer_group():
|
if not has_kv_transfer_group():
|
||||||
# No-op connector.
|
|
||||||
return NO_OP_KV_CONNECTOR
|
return NO_OP_KV_CONNECTOR
|
||||||
|
|
||||||
return ActiveKVConnector(vllm_config, kv_caches_dict)
|
return ActiveKVConnector(
|
||||||
|
vllm_config,
|
||||||
|
kv_caches_dict,
|
||||||
|
packed_backing=packed_backing,
|
||||||
|
packed_block_stride=packed_block_stride,
|
||||||
|
)
|
||||||
|
|||||||
@@ -465,7 +465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.kv_caches: list[torch.Tensor] = []
|
self.kv_caches: list[torch.Tensor] = []
|
||||||
kv_caches_dict = init_kv_cache(
|
kv_caches_dict, packed_backing, packed_block_stride = init_kv_cache(
|
||||||
self.kv_caches,
|
self.kv_caches,
|
||||||
self.compilation_config.static_forward_context,
|
self.compilation_config.static_forward_context,
|
||||||
self.kv_cache_config,
|
self.kv_cache_config,
|
||||||
@@ -475,7 +475,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.kernel_block_sizes,
|
self.kernel_block_sizes,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
)
|
)
|
||||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
self.kv_connector = get_kv_connector(
|
||||||
|
self.vllm_config,
|
||||||
|
kv_caches_dict,
|
||||||
|
packed_backing=packed_backing,
|
||||||
|
packed_block_stride=packed_block_stride,
|
||||||
|
)
|
||||||
|
|
||||||
def _init_kv_zero_meta(self) -> None:
|
def _init_kv_zero_meta(self) -> None:
|
||||||
"""Build KV-block zeroing metadata; invoked from gpu_worker."""
|
"""Build KV-block zeroing metadata; invoked from gpu_worker."""
|
||||||
|
|||||||
@@ -6995,7 +6995,7 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
def _allocate_kv_cache_tensors(
|
def _allocate_kv_cache_tensors(
|
||||||
self, kv_cache_config: KVCacheConfig
|
self, kv_cache_config: KVCacheConfig
|
||||||
) -> dict[str, torch.Tensor]:
|
) -> tuple[dict[str, torch.Tensor], torch.Tensor | None]:
|
||||||
"""
|
"""
|
||||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||||
to be reshaped to the desired shape before being used by the models.
|
to be reshaped to the desired shape before being used by the models.
|
||||||
@@ -7007,12 +7007,20 @@ class GPUModelRunner(
|
|||||||
corresponding memory buffer for KV cache.
|
corresponding memory buffer for KV cache.
|
||||||
"""
|
"""
|
||||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||||
|
packed_backing: torch.Tensor | None = None
|
||||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
tensor = torch.zeros(
|
if kv_cache_tensor.block_stride > 0:
|
||||||
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
if packed_backing is None:
|
||||||
)
|
packed_backing = torch.zeros(
|
||||||
|
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
||||||
|
)
|
||||||
|
backing = packed_backing
|
||||||
|
else:
|
||||||
|
backing = torch.zeros(
|
||||||
|
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
||||||
|
)
|
||||||
for layer_name in kv_cache_tensor.shared_by:
|
for layer_name in kv_cache_tensor.shared_by:
|
||||||
kv_cache_raw_tensors[layer_name] = tensor
|
kv_cache_raw_tensors[layer_name] = backing
|
||||||
|
|
||||||
layer_names = set()
|
layer_names = set()
|
||||||
for group in kv_cache_config.kv_cache_groups:
|
for group in kv_cache_config.kv_cache_groups:
|
||||||
@@ -7023,7 +7031,7 @@ class GPUModelRunner(
|
|||||||
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
||||||
"Some layers are not correctly initialized"
|
"Some layers are not correctly initialized"
|
||||||
)
|
)
|
||||||
return kv_cache_raw_tensors
|
return kv_cache_raw_tensors, packed_backing
|
||||||
|
|
||||||
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
||||||
return itertools.chain.from_iterable(self.attn_groups)
|
return itertools.chain.from_iterable(self.attn_groups)
|
||||||
@@ -7052,6 +7060,12 @@ class GPUModelRunner(
|
|||||||
"""
|
"""
|
||||||
kv_caches: dict[str, torch.Tensor] = {}
|
kv_caches: dict[str, torch.Tensor] = {}
|
||||||
has_attn, has_mamba = False, False
|
has_attn, has_mamba = False, False
|
||||||
|
|
||||||
|
layer_packing: dict[str, tuple[int, int]] = {}
|
||||||
|
for kv_tensor in self.kv_cache_config.kv_cache_tensors:
|
||||||
|
if kv_tensor.block_stride > 0:
|
||||||
|
for ln in kv_tensor.shared_by:
|
||||||
|
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
|
||||||
for group in self._kv_cache_spec_attn_group_iterator():
|
for group in self._kv_cache_spec_attn_group_iterator():
|
||||||
kv_cache_spec = group.kv_cache_spec
|
kv_cache_spec = group.kv_cache_spec
|
||||||
attn_backend = group.backend
|
attn_backend = group.backend
|
||||||
@@ -7063,8 +7077,13 @@ class GPUModelRunner(
|
|||||||
if layer_name in self.runner_only_attn_layers:
|
if layer_name in self.runner_only_attn_layers:
|
||||||
continue
|
continue
|
||||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
packing = layer_packing.get(layer_name)
|
||||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
if packing is not None:
|
||||||
|
_, blk_stride = packing
|
||||||
|
num_blocks = raw_tensor.numel() // blk_stride
|
||||||
|
else:
|
||||||
|
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||||
|
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||||
if isinstance(kv_cache_spec, AttentionSpec):
|
if isinstance(kv_cache_spec, AttentionSpec):
|
||||||
has_attn = True
|
has_attn = True
|
||||||
num_blocks_per_kv_block = (
|
num_blocks_per_kv_block = (
|
||||||
@@ -7106,15 +7125,19 @@ class GPUModelRunner(
|
|||||||
]
|
]
|
||||||
|
|
||||||
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
|
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
|
||||||
if kv_cache_spec.page_size_padded is not None:
|
if packing is not None:
|
||||||
# Use strided view to handle page_size_bytes that
|
layer_offset, blk_stride = packing
|
||||||
# include padding. This follows
|
dtype_size = get_dtype_size(dtype)
|
||||||
# the same pattern as MambaSpec handling below.
|
page_stride = blk_stride // dtype_size
|
||||||
# NOTE: This assumes kv_cache_shape[0] == num_blocks
|
strides = list(torch.empty(kv_cache_shape).stride())
|
||||||
# (i.e. the first physical dimension is the block
|
strides[inv_order[0]] = page_stride
|
||||||
# index), which holds for MLA backends but NOT for
|
kv_cache = torch.as_strided(
|
||||||
# standard attention backends whose shape starts with
|
raw_tensor,
|
||||||
# a K/V dimension of size 2.
|
size=kv_cache_shape,
|
||||||
|
stride=tuple(strides),
|
||||||
|
storage_offset=layer_offset // dtype_size,
|
||||||
|
)
|
||||||
|
elif kv_cache_spec.page_size_padded is not None:
|
||||||
dtype_size = get_dtype_size(dtype)
|
dtype_size = get_dtype_size(dtype)
|
||||||
page_stride = kv_cache_spec.page_size_bytes // dtype_size
|
page_stride = kv_cache_spec.page_size_bytes // dtype_size
|
||||||
strides = list(torch.empty(kv_cache_shape).stride())
|
strides = list(torch.empty(kv_cache_shape).stride())
|
||||||
@@ -7125,7 +7148,6 @@ class GPUModelRunner(
|
|||||||
stride=tuple(strides),
|
stride=tuple(strides),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# No padding — safe to use a contiguous view.
|
|
||||||
kv_cache = raw_tensor.view(kv_cache_shape)
|
kv_cache = raw_tensor.view(kv_cache_shape)
|
||||||
kv_caches[layer_name] = kv_cache.permute(*inv_order)
|
kv_caches[layer_name] = kv_cache.permute(*inv_order)
|
||||||
|
|
||||||
@@ -7227,13 +7249,24 @@ class GPUModelRunner(
|
|||||||
else:
|
else:
|
||||||
# Fallback to the general case
|
# Fallback to the general case
|
||||||
# Initialize the memory buffer for KV cache
|
# Initialize the memory buffer for KV cache
|
||||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
kv_cache_raw_tensors, packed_backing = self._allocate_kv_cache_tensors(
|
||||||
|
kv_cache_config
|
||||||
|
)
|
||||||
|
|
||||||
# Change the memory buffer to the desired shape
|
# Change the memory buffer to the desired shape
|
||||||
kv_caches = self._reshape_kv_cache_tensors(
|
kv_caches = self._reshape_kv_cache_tensors(
|
||||||
kv_cache_raw_tensors, kernel_block_sizes
|
kv_cache_raw_tensors, kernel_block_sizes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if packed_backing is not None:
|
||||||
|
block_stride = next(
|
||||||
|
t.block_stride
|
||||||
|
for t in kv_cache_config.kv_cache_tensors
|
||||||
|
if t.block_stride > 0
|
||||||
|
)
|
||||||
|
self.cross_layers_kv_cache = packed_backing
|
||||||
|
self._packed_block_stride = block_stride
|
||||||
|
|
||||||
# Set up cross-layer KV cache sharing
|
# Set up cross-layer KV cache sharing
|
||||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
|
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
|
||||||
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
|
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
|
||||||
@@ -7329,9 +7362,11 @@ class GPUModelRunner(
|
|||||||
if has_kv_transfer_group() and not is_profiling:
|
if has_kv_transfer_group() and not is_profiling:
|
||||||
kv_transfer_group = get_kv_transfer_group()
|
kv_transfer_group = get_kv_transfer_group()
|
||||||
if self.cross_layers_kv_cache is not None:
|
if self.cross_layers_kv_cache is not None:
|
||||||
assert self.cross_layers_attn_backend is not None
|
block_stride = getattr(self, "_packed_block_stride", None)
|
||||||
kv_transfer_group.register_cross_layers_kv_cache(
|
kv_transfer_group.register_cross_layers_kv_cache(
|
||||||
self.cross_layers_kv_cache, self.cross_layers_attn_backend
|
self.cross_layers_kv_cache,
|
||||||
|
self.cross_layers_attn_backend,
|
||||||
|
block_stride=block_stride,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
kv_transfer_group.register_kv_caches(kv_caches)
|
kv_transfer_group.register_kv_caches(kv_caches)
|
||||||
|
|||||||
Reference in New Issue
Block a user