Pack KV caches into contiguous per-block allocations for DeepSeek V4

For DeepSeek V4, pack all layer data contiguously per block so that
NIXL registers one region per block instead of one per layer.

Full-attention MLA + SWA/compressor caches share one contiguous
allocation per block. Each layer gets an as_strided view with
storage_offset into the packed backing tensor.

The packed backing tensor and block_stride are passed through the
cross-layer KV cache registration API so NIXL registers one region
with block_len=block_stride instead of many separate regions.

Previously: 92 NIXL regions, 92 tiny P2P transfers per block (~16KB).
Now: 1 region, 1 large RDMA transfer per block (~1.48MB).

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Tyler Michael Smith
2026-06-03 14:02:34 -04:00
parent 6a11d72df7
commit 5919036d8c
14 changed files with 441 additions and 79 deletions
+214
View File
@@ -0,0 +1,214 @@
"""Tests for contiguous KV block packing in _get_kv_cache_config_deepseek_v4."""
from unittest.mock import MagicMock
import pytest
import torch
from vllm.v1.core.kv_cache_utils import _get_kv_cache_config_deepseek_v4
from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
MLAAttentionSpec,
UniformTypeKVCacheSpecs,
)
def _make_mla_spec(page_size: int, block_size: int = 256) -> MLAAttentionSpec:
return MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=512,
dtype=torch.uint8,
page_size_padded=page_size,
cache_dtype_str="fp8_ds_mla",
model_version="deepseek_v4",
alignment=576,
)
def _make_groups(n_c4, n_c128, n_swa):
PS_C4_MLA = 37440
PS_C4_IDX = 8640
PS_C128 = 1728
PS_SWA = 37440
mla_specs = {}
for i in range(n_c4):
mla_specs[f"c4_mla.{i}"] = _make_mla_spec(PS_C4_MLA)
mla_specs[f"c4_idx.{i}"] = _make_mla_spec(PS_C4_IDX)
for i in range(n_c128):
mla_specs[f"c128_mla.{i}"] = _make_mla_spec(PS_C128)
mla_group = KVCacheGroupSpec(
layer_names=list(mla_specs.keys()),
kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=mla_specs),
)
swa_specs = {}
for i in range(n_swa):
swa_specs[f"swa.{i}"] = _make_mla_spec(PS_SWA)
swa_group = KVCacheGroupSpec(
layer_names=list(swa_specs.keys()),
kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=swa_specs),
)
return [mla_group, swa_group]
def _mock_vllm_config():
config = MagicMock()
config.scheduler_config.num_gpu_blocks_override = None
return config
def _run(n_c4=3, n_c128=2, n_swa=5, mem=100 * 1024 * 1024):
groups = _make_groups(n_c4, n_c128, n_swa)
return _get_kv_cache_config_deepseek_v4(_mock_vllm_config(), groups, mem)
def _split_tensors(tensors):
mla = [t for t in tensors if any("mla" in n or "idx" in n for n in t.shared_by)]
swa = [t for t in tensors if any("swa" in n for n in t.shared_by)]
return mla, swa
class TestMlaContiguousPacking:
def test_all_mla_share_one_size(self):
_, tensors = _run()
mla, _ = _split_tensors(tensors)
assert len(set(t.size for t in mla)) == 1
def test_all_mla_share_one_block_stride(self):
_, tensors = _run()
mla, _ = _split_tensors(tensors)
strides = set(t.block_stride for t in mla)
assert len(strides) == 1
assert strides.pop() > 0
def test_mla_offsets_are_unique(self):
_, tensors = _run(n_c4=5, n_c128=4)
mla, _ = _split_tensors(tensors)
offsets = [t.offset for t in mla]
assert len(offsets) == len(set(offsets))
def test_mla_offsets_fit_within_block_stride(self):
_, tensors = _run(n_c4=5, n_c128=4)
mla, _ = _split_tensors(tensors)
stride = mla[0].block_stride
for t in mla:
assert 0 <= t.offset < stride, f"offset {t.offset} >= block_stride {stride}"
def test_mla_all_layers_accounted_for(self):
n_c4, n_c128 = 5, 4
_, tensors = _run(n_c4=n_c4, n_c128=n_c128)
mla, _ = _split_tensors(tensors)
all_names = set()
for t in mla:
all_names.update(t.shared_by)
expected_count = n_c4 * 2 + n_c128 # c4_mla + c4_idx + c128_mla
assert len(all_names) == expected_count
def test_size_equals_stride_times_num_blocks(self):
num_blocks, tensors = _run()
mla, _ = _split_tensors(tensors)
for t in mla:
assert t.size == t.block_stride * num_blocks
class TestSwaContiguousPacking:
def test_swa_packed_separately_from_mla(self):
_, tensors = _run()
mla, swa = _split_tensors(tensors)
assert len(swa) > 0
assert len(mla) > 0
assert set(t.size for t in swa) != set(t.size for t in mla)
def test_swa_share_one_size(self):
_, tensors = _run()
_, swa = _split_tensors(tensors)
assert len(set(t.size for t in swa)) == 1
def test_swa_share_one_block_stride(self):
_, tensors = _run()
_, swa = _split_tensors(tensors)
strides = set(t.block_stride for t in swa)
assert len(strides) == 1
assert strides.pop() > 0
def test_swa_offsets_are_unique(self):
_, tensors = _run(n_swa=10)
_, swa = _split_tensors(tensors)
offsets = [t.offset for t in swa]
assert len(offsets) == len(set(offsets))
def test_swa_all_layers_accounted_for(self):
n_swa = 7
_, tensors = _run(n_swa=n_swa)
_, swa = _split_tensors(tensors)
all_names = set()
for t in swa:
all_names.update(t.shared_by)
assert len(all_names) == n_swa
def test_swa_size_equals_stride_times_num_blocks(self):
num_blocks, tensors = _run()
_, swa = _split_tensors(tensors)
for t in swa:
assert t.size == t.block_stride * num_blocks
class TestStridedViewCorrectness:
def test_views_are_independent(self):
page_sizes = [1728, 8640, 37440]
layer_tuple_bytes = sum(page_sizes)
num_tuples = 3
block_stride = layer_tuple_bytes * num_tuples
num_blocks = 4
backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)
views = []
for t in range(num_tuples):
for ps_idx, ps in enumerate(page_sizes):
offset = t * layer_tuple_bytes + sum(page_sizes[:ps_idx])
view = torch.as_strided(
backing,
size=(num_blocks, ps),
stride=(block_stride, 1),
storage_offset=offset,
)
views.append((f"tuple{t}_ps{ps}", view))
for i, (name, view) in enumerate(views):
view.fill_(i + 1)
for j, (_, other) in enumerate(views):
if j <= i:
continue
assert other.sum() == 0, f"Writing to {name} corrupted view {j}"
for i, (name, view) in enumerate(views):
assert view.sum() == (i + 1) * view.numel()
def test_block_isolation(self):
ps = 37440
n_layers = 5
block_stride = ps * n_layers
num_blocks = 3
backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)
views = [
torch.as_strided(backing, (num_blocks, ps), (block_stride, 1), layer * ps)
for layer in range(n_layers)
]
views[2][1].fill_(42)
assert views[2][0].sum() == 0
assert views[2][2].sum() == 0
for i in range(n_layers):
if i != 2:
assert views[i].sum() == 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])
@@ -259,7 +259,10 @@ class KVConnectorBase_V1(ABC):
return return
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] self,
kv_cache: torch.Tensor,
attn_backend: type["AttentionBackend"] | None,
block_stride: int | None = None,
): ):
""" """
Initialize with a single KV cache tensor used by all layers. Initialize with a single KV cache tensor used by all layers.
@@ -248,7 +248,10 @@ class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.register_kv_caches(kv_caches) self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type self,
kv_cache: torch.Tensor,
attn_backend: type | None,
block_stride: int | None = None,
): ):
assert self.connector_worker is not None assert self.connector_worker is not None
assert ( assert (
@@ -236,11 +236,16 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA):
return ret return ret
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
): ):
# Register on all connectors # Register on all connectors
for c in self._connectors: for c in self._connectors:
c.register_cross_layers_kv_cache(kv_cache, attn_backend) c.register_cross_layers_kv_cache(
kv_cache, attn_backend, block_stride=block_stride
)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors: for c in self._connectors:
@@ -210,10 +210,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.register_kv_caches(kv_caches) self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
): ):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_caches(kv_cache) self.connector_worker.register_cross_layers_kv_caches(
kv_cache, block_stride=block_stride
)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None assert self.connector_worker is not None
@@ -774,15 +774,18 @@ class NixlConnectorWorker:
fut.add_done_callback(request_ready) fut.add_done_callback(request_ready)
def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None: def register_cross_layers_kv_caches(
self,
kv_cache: torch.Tensor,
block_stride: int | None = None,
) -> None:
"""Register a cross-layers KV cache tensor with NIXL. """Register a cross-layers KV cache tensor with NIXL.
`use_uniform_kv_cache()` guarantees a single KV cache group whose When block_stride is provided (packed heterogeneous layout),
layers all share the same `AttentionSpec`, so any layer name from it overrides the per-layer page_size calculation.
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
""" """
self._packed_block_stride = block_stride
first_layer = next(iter(self._layer_specs)) first_layer = next(iter(self._layer_specs))
# Forwarding a real layer name rather than a synthetic key
self.register_kv_caches({first_layer: kv_cache}) self.register_kv_caches({first_layer: kv_cache})
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
@@ -859,19 +862,23 @@ class NixlConnectorWorker:
) )
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is # `layer_spec.page_size_bytes` only accounts for logical page_size, that is
# the page_size assuming constant `self._logical_num_blocks`. # the page_size assuming constant `self._logical_num_blocks`.
physical_page_size = ( packed_stride = getattr(self, "_packed_block_stride", None)
layer_spec.page_size_bytes if packed_stride is not None:
if isinstance(layer_spec, MambaSpec) physical_page_size = packed_stride
else layer_spec.page_size_bytes else:
// self._physical_blocks_per_logical_kv_block physical_page_size = (
) layer_spec.page_size_bytes
# For when registering multiple tensors eg K/V in separate regions. if isinstance(layer_spec, MambaSpec)
physical_page_size = physical_page_size // len(cache_list) else layer_spec.page_size_bytes
if self.transfer_topo._cross_layers_blocks: // self._physical_blocks_per_logical_kv_block
# When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors
) )
# For when registering multiple tensors eg K/V in separate regions.
physical_page_size = physical_page_size // len(cache_list)
if self.transfer_topo._cross_layers_blocks:
# When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors
)
num_blocks = ( num_blocks = (
self._logical_num_blocks self._logical_num_blocks
if isinstance(layer_spec, MambaSpec) if isinstance(layer_spec, MambaSpec)
@@ -906,7 +913,7 @@ class NixlConnectorWorker:
else: else:
self.block_len_per_layer.append(physical_page_size) self.block_len_per_layer.append(physical_page_size)
if cache.shape[0] != num_blocks: if packed_stride is None and cache.shape[0] != num_blocks:
raise AssertionError( raise AssertionError(
"All kv cache tensors must have the same number of " "All kv cache tensors must have the same number of "
f"blocks; layer={layer_name}, " f"blocks; layer={layer_name}, "
@@ -183,8 +183,12 @@ class OffloadingConnectorWorker:
self._register_handlers(canonical_kv_caches) self._register_handlers(canonical_kv_caches)
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
): ):
assert attn_backend is not None
# verify that num_blocks is at physical position 0 in the cross-layers # verify that num_blocks is at physical position 0 in the cross-layers
# tensor layout. # tensor layout.
test_shape = attn_backend.get_kv_cache_shape( test_shape = attn_backend.get_kv_cache_shape(
@@ -76,7 +76,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.register_kv_caches(kv_caches) self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
): ):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
+32 -6
View File
@@ -1240,17 +1240,43 @@ def _get_kv_cache_config_deepseek_v4(
num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples) num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples)
num_blocks = may_override_num_blocks(vllm_config, num_blocks) num_blocks = may_override_num_blocks(vllm_config, num_blocks)
full_block_bytes = layer_tuple_page_bytes * num_layer_tuples
total_size = full_block_bytes * num_blocks
kv_cache_tensors: list[KVCacheTensor] = [] kv_cache_tensors: list[KVCacheTensor] = []
for tuple_idx in range(num_layer_tuples): for tuple_idx in range(num_layer_tuples):
ps_offset = 0
for ps in page_sizes: for ps in page_sizes:
shared_by: list[str] = [] mla_shared: list[str] = []
for b in bucketed: swa_shared: list[str] = []
for g_idx, b in enumerate(bucketed):
bucket = b.get(ps) bucket = b.get(ps)
if bucket is not None and tuple_idx < len(bucket): if bucket is not None and tuple_idx < len(bucket):
shared_by.append(bucket[tuple_idx]) if g_idx == 0:
kv_cache_tensors.append( mla_shared.append(bucket[tuple_idx])
KVCacheTensor(size=ps * num_blocks, shared_by=shared_by) else:
) swa_shared.append(bucket[tuple_idx])
offset = tuple_idx * layer_tuple_page_bytes + ps_offset
if mla_shared:
kv_cache_tensors.append(
KVCacheTensor(
size=total_size,
shared_by=mla_shared,
offset=offset,
block_stride=full_block_bytes,
)
)
if swa_shared:
kv_cache_tensors.append(
KVCacheTensor(
size=total_size,
shared_by=swa_shared,
offset=offset,
block_stride=full_block_bytes,
)
)
ps_offset += ps
return num_blocks, kv_cache_tensors return num_blocks, kv_cache_tensors
+2
View File
@@ -834,6 +834,8 @@ class KVCacheTensor:
size: int # size of the KV cache tensor in bytes size: int # size of the KV cache tensor in bytes
shared_by: list[str] # layer names that share the same KV cache tensor shared_by: list[str] # layer names that share the same KV cache tensor
offset: int = 0 # byte offset of this layer within a contiguous block
block_stride: int = 0 # total bytes per block in a packed layout (0 = not packed)
@dataclass @dataclass
+53 -16
View File
@@ -151,8 +151,16 @@ def _allocate_kv_cache(
kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device
): ):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {} kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
packed_backing: torch.Tensor | None = None
for kv_cache_tensor in kv_cache_config.kv_cache_tensors: for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) if kv_cache_tensor.block_stride > 0:
if packed_backing is None:
packed_backing = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=device
)
tensor = packed_backing
else:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by: for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor kv_cache_raw_tensors[layer_name] = tensor
@@ -163,7 +171,7 @@ def _allocate_kv_cache(
assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), ( assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), (
"Some layers are not correctly initialized" "Some layers are not correctly initialized"
) )
return kv_cache_raw_tensors return kv_cache_raw_tensors, packed_backing
def _reshape_kv_cache( def _reshape_kv_cache(
@@ -172,10 +180,18 @@ def _reshape_kv_cache(
cache_dtype: str, cache_dtype: str,
kernel_block_sizes: list[int], kernel_block_sizes: list[int],
shared_kv_cache_layers: dict[str, str], shared_kv_cache_layers: dict[str, str],
kv_cache_config: "KVCacheConfig | None" = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
kv_caches: dict[str, Any] = {} kv_caches: dict[str, Any] = {}
has_attn, has_mamba = False, False has_attn, has_mamba = False, False
layer_packing: dict[str, tuple[int, int]] = {}
if kv_cache_config is not None:
for kv_tensor in kv_cache_config.kv_cache_tensors:
if kv_tensor.block_stride > 0:
for ln in kv_tensor.shared_by:
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
for group in attn_groups: for group in attn_groups:
if group.kv_cache_group_id >= len(kernel_block_sizes): if group.kv_cache_group_id >= len(kernel_block_sizes):
continue continue
@@ -194,8 +210,13 @@ def _reshape_kv_cache(
continue continue
kv_raw_tensor = kv_cache_raw_tensors[layer_name] kv_raw_tensor = kv_cache_raw_tensors[layer_name]
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 packing = layer_packing.get(layer_name)
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes if packing is not None:
_, blk_stride = packing
num_blocks = kv_raw_tensor.numel() // blk_stride
else:
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
@@ -229,14 +250,19 @@ def _reshape_kv_cache(
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
kv_tensor = kv_raw_tensor.view(dtype) kv_tensor = kv_raw_tensor.view(dtype)
if kv_cache_spec.page_size_padded is not None: if packing is not None:
# Use strided view to handle page_size_bytes that layer_offset, blk_stride = packing
# include padding. This follows the same pattern as dtype_size = get_dtype_size(dtype)
# MambaSpec handling in gpu_model_runner.py. page_stride = blk_stride // dtype_size
# NOTE: This assumes kv_cache_shape[0] == num_blocks strides = list(torch.empty(kv_cache_shape).stride())
# (i.e. the first physical dimension is the block strides[inv_order[0]] = page_stride
# index), which holds for all current backends kv_cache = torch.as_strided(
# (MLA, FlashAttention, TritonAttention, etc.). kv_tensor,
size=kv_cache_shape,
stride=tuple(strides),
storage_offset=layer_offset // dtype_size,
)
elif kv_cache_spec.page_size_padded is not None:
dtype_size = get_dtype_size(dtype) dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride()) strides = list(torch.empty(kv_cache_shape).stride())
@@ -247,7 +273,6 @@ def _reshape_kv_cache(
stride=tuple(strides), stride=tuple(strides),
) )
else: else:
# No padding — safe to use a contiguous view.
kv_cache = kv_tensor.view(kv_cache_shape) kv_cache = kv_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order) kv_caches[layer_name] = kv_cache.permute(*inv_order)
@@ -349,9 +374,9 @@ def init_kv_cache(
cache_dtype: str, cache_dtype: str,
kernel_block_sizes: list[int], kernel_block_sizes: list[int],
vllm_config: VllmConfig, vllm_config: VllmConfig,
) -> dict[str, Any]: ) -> tuple[dict[str, Any], torch.Tensor | None, int | None]:
shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config) shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config)
kv_cache_raw_tensors = _allocate_kv_cache( kv_cache_raw_tensors, packed_backing = _allocate_kv_cache(
kv_cache_config, shared_kv_cache_layers, device kv_cache_config, shared_kv_cache_layers, device
) )
flattened_attn_groups = list(group for groups in attn_groups for group in groups) flattened_attn_groups = list(group for groups in attn_groups for group in groups)
@@ -361,9 +386,21 @@ def init_kv_cache(
kernel_block_sizes=kernel_block_sizes, kernel_block_sizes=kernel_block_sizes,
cache_dtype=cache_dtype, cache_dtype=cache_dtype,
shared_kv_cache_layers=shared_kv_cache_layers, shared_kv_cache_layers=shared_kv_cache_layers,
kv_cache_config=kv_cache_config,
) )
bind_kv_cache(kv_caches, forward_context, runner_kv_caches) bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
packed_block_stride = None
if packed_backing is not None:
packed_block_stride = next(
(
t.block_stride
for t in kv_cache_config.kv_cache_tensors
if t.block_stride > 0
),
None,
)
return kv_caches, packed_backing, packed_block_stride
def build_slot_mappings_by_layer( def build_slot_mappings_by_layer(
+21 -8
View File
@@ -46,14 +46,20 @@ class KVConnector:
class ActiveKVConnector(KVConnector): class ActiveKVConnector(KVConnector):
def __init__( def __init__(
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor] self,
vllm_config: VllmConfig,
kv_caches_dict: dict[str, torch.Tensor],
packed_backing: torch.Tensor | None = None,
packed_block_stride: int | None = None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.kv_connector = get_kv_transfer_group() self.kv_connector = get_kv_transfer_group()
# Register kv caches with KV Connector if applicable. if packed_backing is not None and packed_block_stride is not None:
# TODO: support cross_layers_kv_cache self.kv_connector.register_cross_layers_kv_cache(
# (see https://github.com/vllm-project/vllm/pull/27743) packed_backing, None, block_stride=packed_block_stride
self.kv_connector.register_kv_caches(kv_caches_dict) )
else:
self.kv_connector.register_kv_caches(kv_caches_dict)
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks) self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
self._disabled = False self._disabled = False
@@ -114,10 +120,17 @@ NO_OP_KV_CONNECTOR = KVConnector()
def get_kv_connector( def get_kv_connector(
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor] vllm_config: VllmConfig,
kv_caches_dict: dict[str, torch.Tensor],
packed_backing: torch.Tensor | None = None,
packed_block_stride: int | None = None,
) -> KVConnector: ) -> KVConnector:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
# No-op connector.
return NO_OP_KV_CONNECTOR return NO_OP_KV_CONNECTOR
return ActiveKVConnector(vllm_config, kv_caches_dict) return ActiveKVConnector(
vllm_config,
kv_caches_dict,
packed_backing=packed_backing,
packed_block_stride=packed_block_stride,
)
+7 -2
View File
@@ -465,7 +465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
self.kv_caches: list[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
kv_caches_dict = init_kv_cache( kv_caches_dict, packed_backing, packed_block_stride = init_kv_cache(
self.kv_caches, self.kv_caches,
self.compilation_config.static_forward_context, self.compilation_config.static_forward_context,
self.kv_cache_config, self.kv_cache_config,
@@ -475,7 +475,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kernel_block_sizes, self.kernel_block_sizes,
self.vllm_config, self.vllm_config,
) )
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) self.kv_connector = get_kv_connector(
self.vllm_config,
kv_caches_dict,
packed_backing=packed_backing,
packed_block_stride=packed_block_stride,
)
def _init_kv_zero_meta(self) -> None: def _init_kv_zero_meta(self) -> None:
"""Build KV-block zeroing metadata; invoked from gpu_worker.""" """Build KV-block zeroing metadata; invoked from gpu_worker."""
+56 -21
View File
@@ -6995,7 +6995,7 @@ class GPUModelRunner(
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig self, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]: ) -> tuple[dict[str, torch.Tensor], torch.Tensor | None]:
""" """
Initializes the KV cache buffer with the correct size. The buffer needs Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models. to be reshaped to the desired shape before being used by the models.
@@ -7007,12 +7007,20 @@ class GPUModelRunner(
corresponding memory buffer for KV cache. corresponding memory buffer for KV cache.
""" """
kv_cache_raw_tensors: dict[str, torch.Tensor] = {} kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
packed_backing: torch.Tensor | None = None
for kv_cache_tensor in kv_cache_config.kv_cache_tensors: for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros( if kv_cache_tensor.block_stride > 0:
kv_cache_tensor.size, dtype=torch.int8, device=self.device if packed_backing is None:
) packed_backing = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=self.device
)
backing = packed_backing
else:
backing = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=self.device
)
for layer_name in kv_cache_tensor.shared_by: for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor kv_cache_raw_tensors[layer_name] = backing
layer_names = set() layer_names = set()
for group in kv_cache_config.kv_cache_groups: for group in kv_cache_config.kv_cache_groups:
@@ -7023,7 +7031,7 @@ class GPUModelRunner(
assert layer_names == set(kv_cache_raw_tensors.keys()), ( assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized" "Some layers are not correctly initialized"
) )
return kv_cache_raw_tensors return kv_cache_raw_tensors, packed_backing
def _attn_group_iterator(self) -> Iterator[AttentionGroup]: def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups) return itertools.chain.from_iterable(self.attn_groups)
@@ -7052,6 +7060,12 @@ class GPUModelRunner(
""" """
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False has_attn, has_mamba = False, False
layer_packing: dict[str, tuple[int, int]] = {}
for kv_tensor in self.kv_cache_config.kv_cache_tensors:
if kv_tensor.block_stride > 0:
for ln in kv_tensor.shared_by:
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
for group in self._kv_cache_spec_attn_group_iterator(): for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend attn_backend = group.backend
@@ -7063,8 +7077,13 @@ class GPUModelRunner(
if layer_name in self.runner_only_attn_layers: if layer_name in self.runner_only_attn_layers:
continue continue
raw_tensor = kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 packing = layer_packing.get(layer_name)
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if packing is not None:
_, blk_stride = packing
num_blocks = raw_tensor.numel() // blk_stride
else:
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
num_blocks_per_kv_block = ( num_blocks_per_kv_block = (
@@ -7106,15 +7125,19 @@ class GPUModelRunner(
] ]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype) raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
if kv_cache_spec.page_size_padded is not None: if packing is not None:
# Use strided view to handle page_size_bytes that layer_offset, blk_stride = packing
# include padding. This follows dtype_size = get_dtype_size(dtype)
# the same pattern as MambaSpec handling below. page_stride = blk_stride // dtype_size
# NOTE: This assumes kv_cache_shape[0] == num_blocks strides = list(torch.empty(kv_cache_shape).stride())
# (i.e. the first physical dimension is the block strides[inv_order[0]] = page_stride
# index), which holds for MLA backends but NOT for kv_cache = torch.as_strided(
# standard attention backends whose shape starts with raw_tensor,
# a K/V dimension of size 2. size=kv_cache_shape,
stride=tuple(strides),
storage_offset=layer_offset // dtype_size,
)
elif kv_cache_spec.page_size_padded is not None:
dtype_size = get_dtype_size(dtype) dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride()) strides = list(torch.empty(kv_cache_shape).stride())
@@ -7125,7 +7148,6 @@ class GPUModelRunner(
stride=tuple(strides), stride=tuple(strides),
) )
else: else:
# No padding — safe to use a contiguous view.
kv_cache = raw_tensor.view(kv_cache_shape) kv_cache = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order) kv_caches[layer_name] = kv_cache.permute(*inv_order)
@@ -7227,13 +7249,24 @@ class GPUModelRunner(
else: else:
# Fallback to the general case # Fallback to the general case
# Initialize the memory buffer for KV cache # Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) kv_cache_raw_tensors, packed_backing = self._allocate_kv_cache_tensors(
kv_cache_config
)
# Change the memory buffer to the desired shape # Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors( kv_caches = self._reshape_kv_cache_tensors(
kv_cache_raw_tensors, kernel_block_sizes kv_cache_raw_tensors, kernel_block_sizes
) )
if packed_backing is not None:
block_stride = next(
t.block_stride
for t in kv_cache_config.kv_cache_tensors
if t.block_stride > 0
)
self.cross_layers_kv_cache = packed_backing
self._packed_block_stride = block_stride
# Set up cross-layer KV cache sharing # Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
@@ -7329,9 +7362,11 @@ class GPUModelRunner(
if has_kv_transfer_group() and not is_profiling: if has_kv_transfer_group() and not is_profiling:
kv_transfer_group = get_kv_transfer_group() kv_transfer_group = get_kv_transfer_group()
if self.cross_layers_kv_cache is not None: if self.cross_layers_kv_cache is not None:
assert self.cross_layers_attn_backend is not None block_stride = getattr(self, "_packed_block_stride", None)
kv_transfer_group.register_cross_layers_kv_cache( kv_transfer_group.register_cross_layers_kv_cache(
self.cross_layers_kv_cache, self.cross_layers_attn_backend self.cross_layers_kv_cache,
self.cross_layers_attn_backend,
block_stride=block_stride,
) )
else: else:
kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.register_kv_caches(kv_caches)