mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Pack KV caches into contiguous per-block allocations for DeepSeek V4
For DeepSeek V4, pack all layer data contiguously per block so that NIXL registers one region per block instead of one per layer. Full-attention MLA + SWA/compressor caches share one contiguous allocation per block. Each layer gets an as_strided view with storage_offset into the packed backing tensor. The packed backing tensor and block_stride are passed through the cross-layer KV cache registration API so NIXL registers one region with block_len=block_stride instead of many separate regions. Previously: 92 NIXL regions, 92 tiny P2P transfers per block (~16KB). Now: 1 region, 1 large RDMA transfer per block (~1.48MB). Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
@@ -0,0 +1,214 @@
|
||||
"""Tests for contiguous KV block packing in _get_kv_cache_config_deepseek_v4."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import _get_kv_cache_config_deepseek_v4
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
KVCacheGroupSpec,
|
||||
MLAAttentionSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
|
||||
|
||||
def _make_mla_spec(page_size: int, block_size: int = 256) -> MLAAttentionSpec:
|
||||
return MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=512,
|
||||
dtype=torch.uint8,
|
||||
page_size_padded=page_size,
|
||||
cache_dtype_str="fp8_ds_mla",
|
||||
model_version="deepseek_v4",
|
||||
alignment=576,
|
||||
)
|
||||
|
||||
|
||||
def _make_groups(n_c4, n_c128, n_swa):
|
||||
PS_C4_MLA = 37440
|
||||
PS_C4_IDX = 8640
|
||||
PS_C128 = 1728
|
||||
PS_SWA = 37440
|
||||
|
||||
mla_specs = {}
|
||||
for i in range(n_c4):
|
||||
mla_specs[f"c4_mla.{i}"] = _make_mla_spec(PS_C4_MLA)
|
||||
mla_specs[f"c4_idx.{i}"] = _make_mla_spec(PS_C4_IDX)
|
||||
for i in range(n_c128):
|
||||
mla_specs[f"c128_mla.{i}"] = _make_mla_spec(PS_C128)
|
||||
|
||||
mla_group = KVCacheGroupSpec(
|
||||
layer_names=list(mla_specs.keys()),
|
||||
kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=mla_specs),
|
||||
)
|
||||
|
||||
swa_specs = {}
|
||||
for i in range(n_swa):
|
||||
swa_specs[f"swa.{i}"] = _make_mla_spec(PS_SWA)
|
||||
|
||||
swa_group = KVCacheGroupSpec(
|
||||
layer_names=list(swa_specs.keys()),
|
||||
kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=swa_specs),
|
||||
)
|
||||
|
||||
return [mla_group, swa_group]
|
||||
|
||||
|
||||
def _mock_vllm_config():
|
||||
config = MagicMock()
|
||||
config.scheduler_config.num_gpu_blocks_override = None
|
||||
return config
|
||||
|
||||
|
||||
def _run(n_c4=3, n_c128=2, n_swa=5, mem=100 * 1024 * 1024):
|
||||
groups = _make_groups(n_c4, n_c128, n_swa)
|
||||
return _get_kv_cache_config_deepseek_v4(_mock_vllm_config(), groups, mem)
|
||||
|
||||
|
||||
def _split_tensors(tensors):
|
||||
mla = [t for t in tensors if any("mla" in n or "idx" in n for n in t.shared_by)]
|
||||
swa = [t for t in tensors if any("swa" in n for n in t.shared_by)]
|
||||
return mla, swa
|
||||
|
||||
|
||||
class TestMlaContiguousPacking:
|
||||
def test_all_mla_share_one_size(self):
|
||||
_, tensors = _run()
|
||||
mla, _ = _split_tensors(tensors)
|
||||
assert len(set(t.size for t in mla)) == 1
|
||||
|
||||
def test_all_mla_share_one_block_stride(self):
|
||||
_, tensors = _run()
|
||||
mla, _ = _split_tensors(tensors)
|
||||
strides = set(t.block_stride for t in mla)
|
||||
assert len(strides) == 1
|
||||
assert strides.pop() > 0
|
||||
|
||||
def test_mla_offsets_are_unique(self):
|
||||
_, tensors = _run(n_c4=5, n_c128=4)
|
||||
mla, _ = _split_tensors(tensors)
|
||||
offsets = [t.offset for t in mla]
|
||||
assert len(offsets) == len(set(offsets))
|
||||
|
||||
def test_mla_offsets_fit_within_block_stride(self):
|
||||
_, tensors = _run(n_c4=5, n_c128=4)
|
||||
mla, _ = _split_tensors(tensors)
|
||||
stride = mla[0].block_stride
|
||||
for t in mla:
|
||||
assert 0 <= t.offset < stride, f"offset {t.offset} >= block_stride {stride}"
|
||||
|
||||
def test_mla_all_layers_accounted_for(self):
|
||||
n_c4, n_c128 = 5, 4
|
||||
_, tensors = _run(n_c4=n_c4, n_c128=n_c128)
|
||||
mla, _ = _split_tensors(tensors)
|
||||
all_names = set()
|
||||
for t in mla:
|
||||
all_names.update(t.shared_by)
|
||||
expected_count = n_c4 * 2 + n_c128 # c4_mla + c4_idx + c128_mla
|
||||
assert len(all_names) == expected_count
|
||||
|
||||
def test_size_equals_stride_times_num_blocks(self):
|
||||
num_blocks, tensors = _run()
|
||||
mla, _ = _split_tensors(tensors)
|
||||
for t in mla:
|
||||
assert t.size == t.block_stride * num_blocks
|
||||
|
||||
|
||||
class TestSwaContiguousPacking:
|
||||
def test_swa_packed_separately_from_mla(self):
|
||||
_, tensors = _run()
|
||||
mla, swa = _split_tensors(tensors)
|
||||
assert len(swa) > 0
|
||||
assert len(mla) > 0
|
||||
assert set(t.size for t in swa) != set(t.size for t in mla)
|
||||
|
||||
def test_swa_share_one_size(self):
|
||||
_, tensors = _run()
|
||||
_, swa = _split_tensors(tensors)
|
||||
assert len(set(t.size for t in swa)) == 1
|
||||
|
||||
def test_swa_share_one_block_stride(self):
|
||||
_, tensors = _run()
|
||||
_, swa = _split_tensors(tensors)
|
||||
strides = set(t.block_stride for t in swa)
|
||||
assert len(strides) == 1
|
||||
assert strides.pop() > 0
|
||||
|
||||
def test_swa_offsets_are_unique(self):
|
||||
_, tensors = _run(n_swa=10)
|
||||
_, swa = _split_tensors(tensors)
|
||||
offsets = [t.offset for t in swa]
|
||||
assert len(offsets) == len(set(offsets))
|
||||
|
||||
def test_swa_all_layers_accounted_for(self):
|
||||
n_swa = 7
|
||||
_, tensors = _run(n_swa=n_swa)
|
||||
_, swa = _split_tensors(tensors)
|
||||
all_names = set()
|
||||
for t in swa:
|
||||
all_names.update(t.shared_by)
|
||||
assert len(all_names) == n_swa
|
||||
|
||||
def test_swa_size_equals_stride_times_num_blocks(self):
|
||||
num_blocks, tensors = _run()
|
||||
_, swa = _split_tensors(tensors)
|
||||
for t in swa:
|
||||
assert t.size == t.block_stride * num_blocks
|
||||
|
||||
|
||||
class TestStridedViewCorrectness:
|
||||
def test_views_are_independent(self):
|
||||
page_sizes = [1728, 8640, 37440]
|
||||
layer_tuple_bytes = sum(page_sizes)
|
||||
num_tuples = 3
|
||||
block_stride = layer_tuple_bytes * num_tuples
|
||||
num_blocks = 4
|
||||
|
||||
backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)
|
||||
|
||||
views = []
|
||||
for t in range(num_tuples):
|
||||
for ps_idx, ps in enumerate(page_sizes):
|
||||
offset = t * layer_tuple_bytes + sum(page_sizes[:ps_idx])
|
||||
view = torch.as_strided(
|
||||
backing,
|
||||
size=(num_blocks, ps),
|
||||
stride=(block_stride, 1),
|
||||
storage_offset=offset,
|
||||
)
|
||||
views.append((f"tuple{t}_ps{ps}", view))
|
||||
|
||||
for i, (name, view) in enumerate(views):
|
||||
view.fill_(i + 1)
|
||||
for j, (_, other) in enumerate(views):
|
||||
if j <= i:
|
||||
continue
|
||||
assert other.sum() == 0, f"Writing to {name} corrupted view {j}"
|
||||
|
||||
for i, (name, view) in enumerate(views):
|
||||
assert view.sum() == (i + 1) * view.numel()
|
||||
|
||||
def test_block_isolation(self):
|
||||
ps = 37440
|
||||
n_layers = 5
|
||||
block_stride = ps * n_layers
|
||||
num_blocks = 3
|
||||
|
||||
backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)
|
||||
views = [
|
||||
torch.as_strided(backing, (num_blocks, ps), (block_stride, 1), layer * ps)
|
||||
for layer in range(n_layers)
|
||||
]
|
||||
|
||||
views[2][1].fill_(42)
|
||||
assert views[2][0].sum() == 0
|
||||
assert views[2][2].sum() == 0
|
||||
for i in range(n_layers):
|
||||
if i != 2:
|
||||
assert views[i].sum() == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -259,7 +259,10 @@ class KVConnectorBase_V1(ABC):
|
||||
return
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_backend: type["AttentionBackend"] | None,
|
||||
block_stride: int | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize with a single KV cache tensor used by all layers.
|
||||
|
||||
@@ -248,7 +248,10 @@ class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_backend: type | None,
|
||||
block_stride: int | None = None,
|
||||
):
|
||||
assert self.connector_worker is not None
|
||||
assert (
|
||||
|
||||
@@ -236,11 +236,16 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
return ret
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_backend: type[AttentionBackend] | None,
|
||||
block_stride: int | None = None,
|
||||
):
|
||||
# Register on all connectors
|
||||
for c in self._connectors:
|
||||
c.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
c.register_cross_layers_kv_cache(
|
||||
kv_cache, attn_backend, block_stride=block_stride
|
||||
)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for c in self._connectors:
|
||||
|
||||
@@ -210,10 +210,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_backend: type[AttentionBackend] | None,
|
||||
block_stride: int | None = None,
|
||||
):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_cross_layers_kv_caches(kv_cache)
|
||||
self.connector_worker.register_cross_layers_kv_caches(
|
||||
kv_cache, block_stride=block_stride
|
||||
)
|
||||
|
||||
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
|
||||
assert self.connector_worker is not None
|
||||
|
||||
@@ -774,15 +774,18 @@ class NixlConnectorWorker:
|
||||
|
||||
fut.add_done_callback(request_ready)
|
||||
|
||||
def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
|
||||
def register_cross_layers_kv_caches(
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
block_stride: int | None = None,
|
||||
) -> None:
|
||||
"""Register a cross-layers KV cache tensor with NIXL.
|
||||
|
||||
`use_uniform_kv_cache()` guarantees a single KV cache group whose
|
||||
layers all share the same `AttentionSpec`, so any layer name from
|
||||
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
|
||||
When block_stride is provided (packed heterogeneous layout),
|
||||
it overrides the per-layer page_size calculation.
|
||||
"""
|
||||
self._packed_block_stride = block_stride
|
||||
first_layer = next(iter(self._layer_specs))
|
||||
# Forwarding a real layer name rather than a synthetic key
|
||||
self.register_kv_caches({first_layer: kv_cache})
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
@@ -859,19 +862,23 @@ class NixlConnectorWorker:
|
||||
)
|
||||
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
|
||||
# the page_size assuming constant `self._logical_num_blocks`.
|
||||
physical_page_size = (
|
||||
layer_spec.page_size_bytes
|
||||
if isinstance(layer_spec, MambaSpec)
|
||||
else layer_spec.page_size_bytes
|
||||
// self._physical_blocks_per_logical_kv_block
|
||||
)
|
||||
# For when registering multiple tensors eg K/V in separate regions.
|
||||
physical_page_size = physical_page_size // len(cache_list)
|
||||
if self.transfer_topo._cross_layers_blocks:
|
||||
# When cross-layers blocks are used, multiply by number of layers
|
||||
physical_page_size = physical_page_size * len(
|
||||
self.kv_cache_config.kv_cache_tensors
|
||||
packed_stride = getattr(self, "_packed_block_stride", None)
|
||||
if packed_stride is not None:
|
||||
physical_page_size = packed_stride
|
||||
else:
|
||||
physical_page_size = (
|
||||
layer_spec.page_size_bytes
|
||||
if isinstance(layer_spec, MambaSpec)
|
||||
else layer_spec.page_size_bytes
|
||||
// self._physical_blocks_per_logical_kv_block
|
||||
)
|
||||
# For when registering multiple tensors eg K/V in separate regions.
|
||||
physical_page_size = physical_page_size // len(cache_list)
|
||||
if self.transfer_topo._cross_layers_blocks:
|
||||
# When cross-layers blocks are used, multiply by number of layers
|
||||
physical_page_size = physical_page_size * len(
|
||||
self.kv_cache_config.kv_cache_tensors
|
||||
)
|
||||
num_blocks = (
|
||||
self._logical_num_blocks
|
||||
if isinstance(layer_spec, MambaSpec)
|
||||
@@ -906,7 +913,7 @@ class NixlConnectorWorker:
|
||||
else:
|
||||
self.block_len_per_layer.append(physical_page_size)
|
||||
|
||||
if cache.shape[0] != num_blocks:
|
||||
if packed_stride is None and cache.shape[0] != num_blocks:
|
||||
raise AssertionError(
|
||||
"All kv cache tensors must have the same number of "
|
||||
f"blocks; layer={layer_name}, "
|
||||
|
||||
@@ -183,8 +183,12 @@ class OffloadingConnectorWorker:
|
||||
self._register_handlers(canonical_kv_caches)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_backend: type[AttentionBackend] | None,
|
||||
block_stride: int | None = None,
|
||||
):
|
||||
assert attn_backend is not None
|
||||
# verify that num_blocks is at physical position 0 in the cross-layers
|
||||
# tensor layout.
|
||||
test_shape = attn_backend.get_kv_cache_shape(
|
||||
|
||||
@@ -76,7 +76,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_backend: type[AttentionBackend] | None,
|
||||
block_stride: int | None = None,
|
||||
):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
|
||||
@@ -1240,17 +1240,43 @@ def _get_kv_cache_config_deepseek_v4(
|
||||
num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples)
|
||||
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
|
||||
|
||||
full_block_bytes = layer_tuple_page_bytes * num_layer_tuples
|
||||
total_size = full_block_bytes * num_blocks
|
||||
|
||||
kv_cache_tensors: list[KVCacheTensor] = []
|
||||
|
||||
for tuple_idx in range(num_layer_tuples):
|
||||
ps_offset = 0
|
||||
for ps in page_sizes:
|
||||
shared_by: list[str] = []
|
||||
for b in bucketed:
|
||||
mla_shared: list[str] = []
|
||||
swa_shared: list[str] = []
|
||||
for g_idx, b in enumerate(bucketed):
|
||||
bucket = b.get(ps)
|
||||
if bucket is not None and tuple_idx < len(bucket):
|
||||
shared_by.append(bucket[tuple_idx])
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(size=ps * num_blocks, shared_by=shared_by)
|
||||
)
|
||||
if g_idx == 0:
|
||||
mla_shared.append(bucket[tuple_idx])
|
||||
else:
|
||||
swa_shared.append(bucket[tuple_idx])
|
||||
offset = tuple_idx * layer_tuple_page_bytes + ps_offset
|
||||
if mla_shared:
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(
|
||||
size=total_size,
|
||||
shared_by=mla_shared,
|
||||
offset=offset,
|
||||
block_stride=full_block_bytes,
|
||||
)
|
||||
)
|
||||
if swa_shared:
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(
|
||||
size=total_size,
|
||||
shared_by=swa_shared,
|
||||
offset=offset,
|
||||
block_stride=full_block_bytes,
|
||||
)
|
||||
)
|
||||
ps_offset += ps
|
||||
|
||||
return num_blocks, kv_cache_tensors
|
||||
|
||||
|
||||
@@ -834,6 +834,8 @@ class KVCacheTensor:
|
||||
|
||||
size: int # size of the KV cache tensor in bytes
|
||||
shared_by: list[str] # layer names that share the same KV cache tensor
|
||||
offset: int = 0 # byte offset of this layer within a contiguous block
|
||||
block_stride: int = 0 # total bytes per block in a packed layout (0 = not packed)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -151,8 +151,16 @@ def _allocate_kv_cache(
|
||||
kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device
|
||||
):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
packed_backing: torch.Tensor | None = None
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
if kv_cache_tensor.block_stride > 0:
|
||||
if packed_backing is None:
|
||||
packed_backing = torch.zeros(
|
||||
kv_cache_tensor.size, dtype=torch.int8, device=device
|
||||
)
|
||||
tensor = packed_backing
|
||||
else:
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
@@ -163,7 +171,7 @@ def _allocate_kv_cache(
|
||||
assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), (
|
||||
"Some layers are not correctly initialized"
|
||||
)
|
||||
return kv_cache_raw_tensors
|
||||
return kv_cache_raw_tensors, packed_backing
|
||||
|
||||
|
||||
def _reshape_kv_cache(
|
||||
@@ -172,10 +180,18 @@ def _reshape_kv_cache(
|
||||
cache_dtype: str,
|
||||
kernel_block_sizes: list[int],
|
||||
shared_kv_cache_layers: dict[str, str],
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
) -> dict[str, Any]:
|
||||
kv_caches: dict[str, Any] = {}
|
||||
has_attn, has_mamba = False, False
|
||||
|
||||
layer_packing: dict[str, tuple[int, int]] = {}
|
||||
if kv_cache_config is not None:
|
||||
for kv_tensor in kv_cache_config.kv_cache_tensors:
|
||||
if kv_tensor.block_stride > 0:
|
||||
for ln in kv_tensor.shared_by:
|
||||
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
|
||||
|
||||
for group in attn_groups:
|
||||
if group.kv_cache_group_id >= len(kernel_block_sizes):
|
||||
continue
|
||||
@@ -194,8 +210,13 @@ def _reshape_kv_cache(
|
||||
continue
|
||||
|
||||
kv_raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
packing = layer_packing.get(layer_name)
|
||||
if packing is not None:
|
||||
_, blk_stride = packing
|
||||
num_blocks = kv_raw_tensor.numel() // blk_stride
|
||||
else:
|
||||
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
@@ -229,14 +250,19 @@ def _reshape_kv_cache(
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_tensor = kv_raw_tensor.view(dtype)
|
||||
if kv_cache_spec.page_size_padded is not None:
|
||||
# Use strided view to handle page_size_bytes that
|
||||
# include padding. This follows the same pattern as
|
||||
# MambaSpec handling in gpu_model_runner.py.
|
||||
# NOTE: This assumes kv_cache_shape[0] == num_blocks
|
||||
# (i.e. the first physical dimension is the block
|
||||
# index), which holds for all current backends
|
||||
# (MLA, FlashAttention, TritonAttention, etc.).
|
||||
if packing is not None:
|
||||
layer_offset, blk_stride = packing
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
page_stride = blk_stride // dtype_size
|
||||
strides = list(torch.empty(kv_cache_shape).stride())
|
||||
strides[inv_order[0]] = page_stride
|
||||
kv_cache = torch.as_strided(
|
||||
kv_tensor,
|
||||
size=kv_cache_shape,
|
||||
stride=tuple(strides),
|
||||
storage_offset=layer_offset // dtype_size,
|
||||
)
|
||||
elif kv_cache_spec.page_size_padded is not None:
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
page_stride = kv_cache_spec.page_size_bytes // dtype_size
|
||||
strides = list(torch.empty(kv_cache_shape).stride())
|
||||
@@ -247,7 +273,6 @@ def _reshape_kv_cache(
|
||||
stride=tuple(strides),
|
||||
)
|
||||
else:
|
||||
# No padding — safe to use a contiguous view.
|
||||
kv_cache = kv_tensor.view(kv_cache_shape)
|
||||
kv_caches[layer_name] = kv_cache.permute(*inv_order)
|
||||
|
||||
@@ -349,9 +374,9 @@ def init_kv_cache(
|
||||
cache_dtype: str,
|
||||
kernel_block_sizes: list[int],
|
||||
vllm_config: VllmConfig,
|
||||
) -> dict[str, Any]:
|
||||
) -> tuple[dict[str, Any], torch.Tensor | None, int | None]:
|
||||
shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config)
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(
|
||||
kv_cache_raw_tensors, packed_backing = _allocate_kv_cache(
|
||||
kv_cache_config, shared_kv_cache_layers, device
|
||||
)
|
||||
flattened_attn_groups = list(group for groups in attn_groups for group in groups)
|
||||
@@ -361,9 +386,21 @@ def init_kv_cache(
|
||||
kernel_block_sizes=kernel_block_sizes,
|
||||
cache_dtype=cache_dtype,
|
||||
shared_kv_cache_layers=shared_kv_cache_layers,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
return kv_caches
|
||||
|
||||
packed_block_stride = None
|
||||
if packed_backing is not None:
|
||||
packed_block_stride = next(
|
||||
(
|
||||
t.block_stride
|
||||
for t in kv_cache_config.kv_cache_tensors
|
||||
if t.block_stride > 0
|
||||
),
|
||||
None,
|
||||
)
|
||||
return kv_caches, packed_backing, packed_block_stride
|
||||
|
||||
|
||||
def build_slot_mappings_by_layer(
|
||||
|
||||
@@ -46,14 +46,20 @@ class KVConnector:
|
||||
|
||||
class ActiveKVConnector(KVConnector):
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_caches_dict: dict[str, torch.Tensor],
|
||||
packed_backing: torch.Tensor | None = None,
|
||||
packed_block_stride: int | None = None,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.kv_connector = get_kv_transfer_group()
|
||||
# Register kv caches with KV Connector if applicable.
|
||||
# TODO: support cross_layers_kv_cache
|
||||
# (see https://github.com/vllm-project/vllm/pull/27743)
|
||||
self.kv_connector.register_kv_caches(kv_caches_dict)
|
||||
if packed_backing is not None and packed_block_stride is not None:
|
||||
self.kv_connector.register_cross_layers_kv_cache(
|
||||
packed_backing, None, block_stride=packed_block_stride
|
||||
)
|
||||
else:
|
||||
self.kv_connector.register_kv_caches(kv_caches_dict)
|
||||
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||
|
||||
self._disabled = False
|
||||
@@ -114,10 +120,17 @@ NO_OP_KV_CONNECTOR = KVConnector()
|
||||
|
||||
|
||||
def get_kv_connector(
|
||||
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
|
||||
vllm_config: VllmConfig,
|
||||
kv_caches_dict: dict[str, torch.Tensor],
|
||||
packed_backing: torch.Tensor | None = None,
|
||||
packed_block_stride: int | None = None,
|
||||
) -> KVConnector:
|
||||
if not has_kv_transfer_group():
|
||||
# No-op connector.
|
||||
return NO_OP_KV_CONNECTOR
|
||||
|
||||
return ActiveKVConnector(vllm_config, kv_caches_dict)
|
||||
return ActiveKVConnector(
|
||||
vllm_config,
|
||||
kv_caches_dict,
|
||||
packed_backing=packed_backing,
|
||||
packed_block_stride=packed_block_stride,
|
||||
)
|
||||
|
||||
@@ -465,7 +465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
self.kv_caches: list[torch.Tensor] = []
|
||||
kv_caches_dict = init_kv_cache(
|
||||
kv_caches_dict, packed_backing, packed_block_stride = init_kv_cache(
|
||||
self.kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_cache_config,
|
||||
@@ -475,7 +475,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kernel_block_sizes,
|
||||
self.vllm_config,
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
self.kv_connector = get_kv_connector(
|
||||
self.vllm_config,
|
||||
kv_caches_dict,
|
||||
packed_backing=packed_backing,
|
||||
packed_block_stride=packed_block_stride,
|
||||
)
|
||||
|
||||
def _init_kv_zero_meta(self) -> None:
|
||||
"""Build KV-block zeroing metadata; invoked from gpu_worker."""
|
||||
|
||||
@@ -6995,7 +6995,7 @@ class GPUModelRunner(
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig
|
||||
) -> dict[str, torch.Tensor]:
|
||||
) -> tuple[dict[str, torch.Tensor], torch.Tensor | None]:
|
||||
"""
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
to be reshaped to the desired shape before being used by the models.
|
||||
@@ -7007,12 +7007,20 @@ class GPUModelRunner(
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
packed_backing: torch.Tensor | None = None
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(
|
||||
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
if kv_cache_tensor.block_stride > 0:
|
||||
if packed_backing is None:
|
||||
packed_backing = torch.zeros(
|
||||
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
backing = packed_backing
|
||||
else:
|
||||
backing = torch.zeros(
|
||||
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
kv_cache_raw_tensors[layer_name] = backing
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
@@ -7023,7 +7031,7 @@ class GPUModelRunner(
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
||||
"Some layers are not correctly initialized"
|
||||
)
|
||||
return kv_cache_raw_tensors
|
||||
return kv_cache_raw_tensors, packed_backing
|
||||
|
||||
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
||||
return itertools.chain.from_iterable(self.attn_groups)
|
||||
@@ -7052,6 +7060,12 @@ class GPUModelRunner(
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
has_attn, has_mamba = False, False
|
||||
|
||||
layer_packing: dict[str, tuple[int, int]] = {}
|
||||
for kv_tensor in self.kv_cache_config.kv_cache_tensors:
|
||||
if kv_tensor.block_stride > 0:
|
||||
for ln in kv_tensor.shared_by:
|
||||
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
@@ -7063,8 +7077,13 @@ class GPUModelRunner(
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
packing = layer_packing.get(layer_name)
|
||||
if packing is not None:
|
||||
_, blk_stride = packing
|
||||
num_blocks = raw_tensor.numel() // blk_stride
|
||||
else:
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
num_blocks_per_kv_block = (
|
||||
@@ -7106,15 +7125,19 @@ class GPUModelRunner(
|
||||
]
|
||||
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
|
||||
if kv_cache_spec.page_size_padded is not None:
|
||||
# Use strided view to handle page_size_bytes that
|
||||
# include padding. This follows
|
||||
# the same pattern as MambaSpec handling below.
|
||||
# NOTE: This assumes kv_cache_shape[0] == num_blocks
|
||||
# (i.e. the first physical dimension is the block
|
||||
# index), which holds for MLA backends but NOT for
|
||||
# standard attention backends whose shape starts with
|
||||
# a K/V dimension of size 2.
|
||||
if packing is not None:
|
||||
layer_offset, blk_stride = packing
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
page_stride = blk_stride // dtype_size
|
||||
strides = list(torch.empty(kv_cache_shape).stride())
|
||||
strides[inv_order[0]] = page_stride
|
||||
kv_cache = torch.as_strided(
|
||||
raw_tensor,
|
||||
size=kv_cache_shape,
|
||||
stride=tuple(strides),
|
||||
storage_offset=layer_offset // dtype_size,
|
||||
)
|
||||
elif kv_cache_spec.page_size_padded is not None:
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
page_stride = kv_cache_spec.page_size_bytes // dtype_size
|
||||
strides = list(torch.empty(kv_cache_shape).stride())
|
||||
@@ -7125,7 +7148,6 @@ class GPUModelRunner(
|
||||
stride=tuple(strides),
|
||||
)
|
||||
else:
|
||||
# No padding — safe to use a contiguous view.
|
||||
kv_cache = raw_tensor.view(kv_cache_shape)
|
||||
kv_caches[layer_name] = kv_cache.permute(*inv_order)
|
||||
|
||||
@@ -7227,13 +7249,24 @@ class GPUModelRunner(
|
||||
else:
|
||||
# Fallback to the general case
|
||||
# Initialize the memory buffer for KV cache
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
kv_cache_raw_tensors, packed_backing = self._allocate_kv_cache_tensors(
|
||||
kv_cache_config
|
||||
)
|
||||
|
||||
# Change the memory buffer to the desired shape
|
||||
kv_caches = self._reshape_kv_cache_tensors(
|
||||
kv_cache_raw_tensors, kernel_block_sizes
|
||||
)
|
||||
|
||||
if packed_backing is not None:
|
||||
block_stride = next(
|
||||
t.block_stride
|
||||
for t in kv_cache_config.kv_cache_tensors
|
||||
if t.block_stride > 0
|
||||
)
|
||||
self.cross_layers_kv_cache = packed_backing
|
||||
self._packed_block_stride = block_stride
|
||||
|
||||
# Set up cross-layer KV cache sharing
|
||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
|
||||
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
|
||||
@@ -7329,9 +7362,11 @@ class GPUModelRunner(
|
||||
if has_kv_transfer_group() and not is_profiling:
|
||||
kv_transfer_group = get_kv_transfer_group()
|
||||
if self.cross_layers_kv_cache is not None:
|
||||
assert self.cross_layers_attn_backend is not None
|
||||
block_stride = getattr(self, "_packed_block_stride", None)
|
||||
kv_transfer_group.register_cross_layers_kv_cache(
|
||||
self.cross_layers_kv_cache, self.cross_layers_attn_backend
|
||||
self.cross_layers_kv_cache,
|
||||
self.cross_layers_attn_backend,
|
||||
block_stride=block_stride,
|
||||
)
|
||||
else:
|
||||
kv_transfer_group.register_kv_caches(kv_caches)
|
||||
|
||||
Reference in New Issue
Block a user