Pack KV caches into contiguous per-block allocations for DeepSeek V4

For DeepSeek V4, pack all layer data contiguously per block so that
NIXL registers one region per block instead of one per layer.

Full-attention MLA + SWA/compressor caches share one contiguous
allocation per block. Each layer gets an as_strided view with
storage_offset into the packed backing tensor.

The packed backing tensor and block_stride are passed through the
cross-layer KV cache registration API so NIXL registers one region
with block_len=block_stride instead of many separate regions.

Previously: 92 NIXL regions and 92 tiny P2P transfers (~16KB each) per block.
Now: 1 NIXL region and 1 large RDMA transfer (~1.48MB) per block.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Tyler Michael Smith
2026-06-03 14:02:34 -04:00
parent 6a11d72df7
commit 5919036d8c
14 changed files with 441 additions and 79 deletions
+214
View File
@@ -0,0 +1,214 @@
"""Tests for contiguous KV block packing in _get_kv_cache_config_deepseek_v4."""
from unittest.mock import MagicMock
import pytest
import torch
from vllm.v1.core.kv_cache_utils import _get_kv_cache_config_deepseek_v4
from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
MLAAttentionSpec,
UniformTypeKVCacheSpecs,
)
def _make_mla_spec(page_size: int, block_size: int = 256) -> MLAAttentionSpec:
    """Build the ``MLAAttentionSpec`` fixture used for every synthetic layer.

    Args:
        page_size: Value passed through as ``page_size_padded``.
        block_size: Value passed through as ``block_size`` (defaults to 256).

    Returns:
        An ``MLAAttentionSpec`` whose remaining fields are fixed test
        constants (single KV head, head size 512, ``uint8`` storage,
        ``fp8_ds_mla`` cache dtype, ``deepseek_v4`` model version and an
        alignment of 576).
    """
    spec_kwargs = dict(
        block_size=block_size,
        num_kv_heads=1,
        head_size=512,
        dtype=torch.uint8,
        page_size_padded=page_size,
        cache_dtype_str="fp8_ds_mla",
        model_version="deepseek_v4",
        alignment=576,
    )
    return MLAAttentionSpec(**spec_kwargs)
def _make_groups(n_c4, n_c128, n_swa):
    """Build the two synthetic KV cache groups fed to the packing function.

    Args:
        n_c4: Number of ``c4_mla.<i>`` / ``c4_idx.<i>`` layer pairs.
        n_c128: Number of ``c128_mla.<i>`` layers.
        n_swa: Number of ``swa.<i>`` layers.

    Returns:
        ``[mla_group, swa_group]``. The first group holds the c4 and c128
        layers, the second holds the SWA layers. Layer names are listed in
        insertion order: each ``c4_mla.<i>`` is followed by its
        ``c4_idx.<i>``, then all ``c128_mla.<i>`` entries.
    """
    # Page sizes assigned to each kind of synthetic layer.
    c4_mla_page = 37440
    c4_idx_page = 8640
    c128_page = 1728
    swa_page = 37440

    def _as_group(specs):
        # Wrap a name -> spec mapping into a uniform-type cache group.
        return KVCacheGroupSpec(
            layer_names=list(specs),
            kv_cache_spec=UniformTypeKVCacheSpecs(
                block_size=256, kv_cache_specs=specs
            ),
        )

    mla_layers = {}
    for i in range(n_c4):
        mla_layers.update(
            {
                f"c4_mla.{i}": _make_mla_spec(c4_mla_page),
                f"c4_idx.{i}": _make_mla_spec(c4_idx_page),
            }
        )
    mla_layers.update(
        {f"c128_mla.{i}": _make_mla_spec(c128_page) for i in range(n_c128)}
    )
    mla_group = _as_group(mla_layers)

    swa_layers = {f"swa.{i}": _make_mla_spec(swa_page) for i in range(n_swa)}
    swa_group = _as_group(swa_layers)

    return [mla_group, swa_group]
def _mock_vllm_config():
config = MagicMock()
config.scheduler_config.num_gpu_blocks_override = None
return config
def _run(n_c4=3, n_c128=2, n_swa=5, mem=100 * 1024 * 1024):
    """Run ``_get_kv_cache_config_deepseek_v4`` on synthetic cache groups.

    Args:
        n_c4: Forwarded to ``_make_groups``.
        n_c128: Forwarded to ``_make_groups``.
        n_swa: Forwarded to ``_make_groups``.
        mem: Third positional argument of the function under test
            (defaults to 100 MiB).

    Returns:
        Whatever the function under test returns; the tests in this file
        unpack it as ``(num_blocks, kv_cache_tensors)``.
    """
    kv_cache_groups = _make_groups(n_c4, n_c128, n_swa)
    vllm_config = _mock_vllm_config()
    return _get_kv_cache_config_deepseek_v4(vllm_config, kv_cache_groups, mem)
def _split_tensors(tensors):
mla = [t for t in tensors if any("mla" in n or "idx" in n for n in t.shared_by)]
swa = [t for t in tensors if any("swa" in n for n in t.shared_by)]
return mla, swa
class TestMlaContiguousPacking:
    """Checks on the tensor entries whose layer names contain "mla" or "idx"."""

    def test_all_mla_share_one_size(self):
        # Every MLA entry must report the same total size.
        _, tensors = _run()
        mla = _split_tensors(tensors)[0]
        distinct_sizes = {t.size for t in mla}
        assert len(distinct_sizes) == 1

    def test_all_mla_share_one_block_stride(self):
        # One common, strictly positive block stride across all MLA entries.
        _, tensors = _run()
        mla = _split_tensors(tensors)[0]
        strides = {t.block_stride for t in mla}
        assert len(strides) == 1
        assert next(iter(strides)) > 0

    def test_mla_offsets_are_unique(self):
        # No two MLA entries may start at the same offset.
        _, tensors = _run(n_c4=5, n_c128=4)
        mla = _split_tensors(tensors)[0]
        offsets = [t.offset for t in mla]
        assert len(set(offsets)) == len(offsets)

    def test_mla_offsets_fit_within_block_stride(self):
        # Each offset must lie in [0, block_stride).
        _, tensors = _run(n_c4=5, n_c128=4)
        mla = _split_tensors(tensors)[0]
        stride = mla[0].block_stride
        for t in mla:
            assert 0 <= t.offset < stride, f"offset {t.offset} >= block_stride {stride}"

    def test_mla_all_layers_accounted_for(self):
        # Every generated MLA-group layer name appears in some entry.
        n_c4, n_c128 = 5, 4
        _, tensors = _run(n_c4=n_c4, n_c128=n_c128)
        mla = _split_tensors(tensors)[0]
        seen_names = {name for t in mla for name in t.shared_by}
        # c4_mla + c4_idx per c4 layer, plus one c128_mla per c128 layer.
        expected_count = 2 * n_c4 + n_c128
        assert len(seen_names) == expected_count

    def test_size_equals_stride_times_num_blocks(self):
        # Total size must be exactly block_stride * num_blocks.
        num_blocks, tensors = _run()
        mla = _split_tensors(tensors)[0]
        for t in mla:
            expected_size = t.block_stride * num_blocks
            assert t.size == expected_size
class TestSwaContiguousPacking:
    """Checks on the tensor entries whose layer names contain "swa"."""

    def test_swa_packed_separately_from_mla(self):
        """SWA layers get their own entries inside the shared packed layout.

        The packing code emits MLA and SWA entries with the same ``size``
        and ``block_stride`` (both describe the one packed per-block
        allocation), so comparing sizes for inequality can never succeed.
        What "separately" can mean is that no entry mixes the two kinds of
        layers, which is what is asserted here.
        """
        _, tensors = _run()
        mla, swa = _split_tensors(tensors)
        assert len(swa) > 0
        assert len(mla) > 0

        # No single entry is classified as both MLA and SWA. Compare by
        # identity because the entries may not be hashable.
        mla_entry_ids = {id(t) for t in mla}
        assert all(id(t) not in mla_entry_ids for t in swa)

        # The two sets of layer names do not overlap.
        mla_names = {name for t in mla for name in t.shared_by}
        swa_names = {name for t in swa for name in t.shared_by}
        assert mla_names.isdisjoint(swa_names)

        # Both kinds of entries describe the same packed backing layout.
        assert {t.size for t in swa} == {t.size for t in mla}
        assert {t.block_stride for t in swa} == {t.block_stride for t in mla}

    def test_swa_share_one_size(self):
        # Every SWA entry must report the same total size.
        _, tensors = _run()
        _, swa = _split_tensors(tensors)
        assert len({t.size for t in swa}) == 1

    def test_swa_share_one_block_stride(self):
        # One common, strictly positive block stride across all SWA entries.
        _, tensors = _run()
        _, swa = _split_tensors(tensors)
        strides = {t.block_stride for t in swa}
        assert len(strides) == 1
        assert strides.pop() > 0

    def test_swa_offsets_are_unique(self):
        # No two SWA entries may start at the same offset.
        _, tensors = _run(n_swa=10)
        _, swa = _split_tensors(tensors)
        offsets = [t.offset for t in swa]
        assert len(offsets) == len(set(offsets))

    def test_swa_all_layers_accounted_for(self):
        # Every generated SWA layer name appears in some entry.
        n_swa = 7
        _, tensors = _run(n_swa=n_swa)
        _, swa = _split_tensors(tensors)
        all_names = {name for t in swa for name in t.shared_by}
        assert len(all_names) == n_swa

    def test_swa_size_equals_stride_times_num_blocks(self):
        # Total size must be exactly block_stride * num_blocks.
        num_blocks, tensors = _run()
        _, swa = _split_tensors(tensors)
        for t in swa:
            assert t.size == t.block_stride * num_blocks
class TestStridedViewCorrectness:
    """Checks that ``torch.as_strided`` views into one packed buffer do not overlap."""

    def test_views_are_independent(self):
        """Filling one layer view must leave every other layer view untouched."""
        page_sizes = [1728, 8640, 37440]
        layer_tuple_bytes = sum(page_sizes)
        num_tuples = 3
        num_blocks = 4
        block_stride = layer_tuple_bytes * num_tuples
        backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)

        views = []
        for t in range(num_tuples):
            # Running byte offset of the next layer within one block.
            offset = t * layer_tuple_bytes
            for ps in page_sizes:
                view = torch.as_strided(
                    backing,
                    size=(num_blocks, ps),
                    stride=(block_stride, 1),
                    storage_offset=offset,
                )
                views.append((f"tuple{t}_ps{ps}", view))
                offset += ps

        # Fill views in order; views not yet written must still be all zero.
        for i, (name, view) in enumerate(views):
            view.fill_(i + 1)
            for j in range(i + 1, len(views)):
                other = views[j][1]
                assert other.sum() == 0, f"Writing to {name} corrupted view {j}"

        # After all fills, each view still holds only its own marker value.
        for i, (_, view) in enumerate(views):
            assert view.sum() == (i + 1) * view.numel()

    def test_block_isolation(self):
        """Writing one block of one layer must not touch other blocks or layers."""
        ps = 37440
        n_layers = 5
        num_blocks = 3
        block_stride = ps * n_layers
        backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8)

        views = []
        for layer in range(n_layers):
            views.append(
                torch.as_strided(
                    backing, (num_blocks, ps), (block_stride, 1), layer * ps
                )
            )

        target = views[2]
        target[1].fill_(42)

        # Neighbouring blocks of the written layer stay zero.
        assert target[0].sum() == 0
        assert target[2].sum() == 0

        # Every other layer stays zero in all blocks.
        for i, view in enumerate(views):
            if i == 2:
                continue
            assert view.sum() == 0
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
@@ -259,7 +259,10 @@ class KVConnectorBase_V1(ABC):
return
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
self,
kv_cache: torch.Tensor,
attn_backend: type["AttentionBackend"] | None,
block_stride: int | None = None,
):
"""
Initialize with a single KV cache tensor used by all layers.
@@ -248,7 +248,10 @@ class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type
self,
kv_cache: torch.Tensor,
attn_backend: type | None,
block_stride: int | None = None,
):
assert self.connector_worker is not None
assert (
@@ -236,11 +236,16 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA):
return ret
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
):
# Register on all connectors
for c in self._connectors:
c.register_cross_layers_kv_cache(kv_cache, attn_backend)
c.register_cross_layers_kv_cache(
kv_cache, attn_backend, block_stride=block_stride
)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
@@ -210,10 +210,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_caches(kv_cache)
self.connector_worker.register_cross_layers_kv_caches(
kv_cache, block_stride=block_stride
)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
@@ -774,15 +774,18 @@ class NixlConnectorWorker:
fut.add_done_callback(request_ready)
def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
def register_cross_layers_kv_caches(
self,
kv_cache: torch.Tensor,
block_stride: int | None = None,
) -> None:
"""Register a cross-layers KV cache tensor with NIXL.
`use_uniform_kv_cache()` guarantees a single KV cache group whose
layers all share the same `AttentionSpec`, so any layer name from
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
When block_stride is provided (packed heterogeneous layout),
it overrides the per-layer page_size calculation.
"""
self._packed_block_stride = block_stride
first_layer = next(iter(self._layer_specs))
# Forwarding a real layer name rather than a synthetic key
self.register_kv_caches({first_layer: kv_cache})
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
@@ -859,19 +862,23 @@ class NixlConnectorWorker:
)
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
# the page_size assuming constant `self._logical_num_blocks`.
physical_page_size = (
layer_spec.page_size_bytes
if isinstance(layer_spec, MambaSpec)
else layer_spec.page_size_bytes
// self._physical_blocks_per_logical_kv_block
)
# For when registering multiple tensors eg K/V in separate regions.
physical_page_size = physical_page_size // len(cache_list)
if self.transfer_topo._cross_layers_blocks:
# When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors
packed_stride = getattr(self, "_packed_block_stride", None)
if packed_stride is not None:
physical_page_size = packed_stride
else:
physical_page_size = (
layer_spec.page_size_bytes
if isinstance(layer_spec, MambaSpec)
else layer_spec.page_size_bytes
// self._physical_blocks_per_logical_kv_block
)
# For when registering multiple tensors eg K/V in separate regions.
physical_page_size = physical_page_size // len(cache_list)
if self.transfer_topo._cross_layers_blocks:
# When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors
)
num_blocks = (
self._logical_num_blocks
if isinstance(layer_spec, MambaSpec)
@@ -906,7 +913,7 @@ class NixlConnectorWorker:
else:
self.block_len_per_layer.append(physical_page_size)
if cache.shape[0] != num_blocks:
if packed_stride is None and cache.shape[0] != num_blocks:
raise AssertionError(
"All kv cache tensors must have the same number of "
f"blocks; layer={layer_name}, "
@@ -183,8 +183,12 @@ class OffloadingConnectorWorker:
self._register_handlers(canonical_kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
):
assert attn_backend is not None
# verify that num_blocks is at physical position 0 in the cross-layers
# tensor layout.
test_shape = attn_backend.get_kv_cache_shape(
@@ -76,7 +76,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
self,
kv_cache: torch.Tensor,
attn_backend: type[AttentionBackend] | None,
block_stride: int | None = None,
):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
+32 -6
View File
@@ -1240,17 +1240,43 @@ def _get_kv_cache_config_deepseek_v4(
num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples)
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
full_block_bytes = layer_tuple_page_bytes * num_layer_tuples
total_size = full_block_bytes * num_blocks
kv_cache_tensors: list[KVCacheTensor] = []
for tuple_idx in range(num_layer_tuples):
ps_offset = 0
for ps in page_sizes:
shared_by: list[str] = []
for b in bucketed:
mla_shared: list[str] = []
swa_shared: list[str] = []
for g_idx, b in enumerate(bucketed):
bucket = b.get(ps)
if bucket is not None and tuple_idx < len(bucket):
shared_by.append(bucket[tuple_idx])
kv_cache_tensors.append(
KVCacheTensor(size=ps * num_blocks, shared_by=shared_by)
)
if g_idx == 0:
mla_shared.append(bucket[tuple_idx])
else:
swa_shared.append(bucket[tuple_idx])
offset = tuple_idx * layer_tuple_page_bytes + ps_offset
if mla_shared:
kv_cache_tensors.append(
KVCacheTensor(
size=total_size,
shared_by=mla_shared,
offset=offset,
block_stride=full_block_bytes,
)
)
if swa_shared:
kv_cache_tensors.append(
KVCacheTensor(
size=total_size,
shared_by=swa_shared,
offset=offset,
block_stride=full_block_bytes,
)
)
ps_offset += ps
return num_blocks, kv_cache_tensors
+2
View File
@@ -834,6 +834,8 @@ class KVCacheTensor:
size: int # size of the KV cache tensor in bytes
shared_by: list[str] # layer names that share the same KV cache tensor
offset: int = 0 # byte offset of this layer within a contiguous block
block_stride: int = 0 # total bytes per block in a packed layout (0 = not packed)
@dataclass
+53 -16
View File
@@ -151,8 +151,16 @@ def _allocate_kv_cache(
kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
packed_backing: torch.Tensor | None = None
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
if kv_cache_tensor.block_stride > 0:
if packed_backing is None:
packed_backing = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=device
)
tensor = packed_backing
else:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
@@ -163,7 +171,7 @@ def _allocate_kv_cache(
assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
return kv_cache_raw_tensors, packed_backing
def _reshape_kv_cache(
@@ -172,10 +180,18 @@ def _reshape_kv_cache(
cache_dtype: str,
kernel_block_sizes: list[int],
shared_kv_cache_layers: dict[str, str],
kv_cache_config: "KVCacheConfig | None" = None,
) -> dict[str, Any]:
kv_caches: dict[str, Any] = {}
has_attn, has_mamba = False, False
layer_packing: dict[str, tuple[int, int]] = {}
if kv_cache_config is not None:
for kv_tensor in kv_cache_config.kv_cache_tensors:
if kv_tensor.block_stride > 0:
for ln in kv_tensor.shared_by:
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
for group in attn_groups:
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
@@ -194,8 +210,13 @@ def _reshape_kv_cache(
continue
kv_raw_tensor = kv_cache_raw_tensors[layer_name]
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
packing = layer_packing.get(layer_name)
if packing is not None:
_, blk_stride = packing
num_blocks = kv_raw_tensor.numel() // blk_stride
else:
assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
@@ -229,14 +250,19 @@ def _reshape_kv_cache(
dtype = kv_cache_spec.dtype
kv_tensor = kv_raw_tensor.view(dtype)
if kv_cache_spec.page_size_padded is not None:
# Use strided view to handle page_size_bytes that
# include padding. This follows the same pattern as
# MambaSpec handling in gpu_model_runner.py.
# NOTE: This assumes kv_cache_shape[0] == num_blocks
# (i.e. the first physical dimension is the block
# index), which holds for all current backends
# (MLA, FlashAttention, TritonAttention, etc.).
if packing is not None:
layer_offset, blk_stride = packing
dtype_size = get_dtype_size(dtype)
page_stride = blk_stride // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
strides[inv_order[0]] = page_stride
kv_cache = torch.as_strided(
kv_tensor,
size=kv_cache_shape,
stride=tuple(strides),
storage_offset=layer_offset // dtype_size,
)
elif kv_cache_spec.page_size_padded is not None:
dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
@@ -247,7 +273,6 @@ def _reshape_kv_cache(
stride=tuple(strides),
)
else:
# No padding — safe to use a contiguous view.
kv_cache = kv_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order)
@@ -349,9 +374,9 @@ def init_kv_cache(
cache_dtype: str,
kernel_block_sizes: list[int],
vllm_config: VllmConfig,
) -> dict[str, Any]:
) -> tuple[dict[str, Any], torch.Tensor | None, int | None]:
shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config)
kv_cache_raw_tensors = _allocate_kv_cache(
kv_cache_raw_tensors, packed_backing = _allocate_kv_cache(
kv_cache_config, shared_kv_cache_layers, device
)
flattened_attn_groups = list(group for groups in attn_groups for group in groups)
@@ -361,9 +386,21 @@ def init_kv_cache(
kernel_block_sizes=kernel_block_sizes,
cache_dtype=cache_dtype,
shared_kv_cache_layers=shared_kv_cache_layers,
kv_cache_config=kv_cache_config,
)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
packed_block_stride = None
if packed_backing is not None:
packed_block_stride = next(
(
t.block_stride
for t in kv_cache_config.kv_cache_tensors
if t.block_stride > 0
),
None,
)
return kv_caches, packed_backing, packed_block_stride
def build_slot_mappings_by_layer(
+21 -8
View File
@@ -46,14 +46,20 @@ class KVConnector:
class ActiveKVConnector(KVConnector):
def __init__(
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
self,
vllm_config: VllmConfig,
kv_caches_dict: dict[str, torch.Tensor],
packed_backing: torch.Tensor | None = None,
packed_block_stride: int | None = None,
):
self.vllm_config = vllm_config
self.kv_connector = get_kv_transfer_group()
# Register kv caches with KV Connector if applicable.
# TODO: support cross_layers_kv_cache
# (see https://github.com/vllm-project/vllm/pull/27743)
self.kv_connector.register_kv_caches(kv_caches_dict)
if packed_backing is not None and packed_block_stride is not None:
self.kv_connector.register_cross_layers_kv_cache(
packed_backing, None, block_stride=packed_block_stride
)
else:
self.kv_connector.register_kv_caches(kv_caches_dict)
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
self._disabled = False
@@ -114,10 +120,17 @@ NO_OP_KV_CONNECTOR = KVConnector()
def get_kv_connector(
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
vllm_config: VllmConfig,
kv_caches_dict: dict[str, torch.Tensor],
packed_backing: torch.Tensor | None = None,
packed_block_stride: int | None = None,
) -> KVConnector:
if not has_kv_transfer_group():
# No-op connector.
return NO_OP_KV_CONNECTOR
return ActiveKVConnector(vllm_config, kv_caches_dict)
return ActiveKVConnector(
vllm_config,
kv_caches_dict,
packed_backing=packed_backing,
packed_block_stride=packed_block_stride,
)
+7 -2
View File
@@ -465,7 +465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.kv_caches: list[torch.Tensor] = []
kv_caches_dict = init_kv_cache(
kv_caches_dict, packed_backing, packed_block_stride = init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
@@ -475,7 +475,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kernel_block_sizes,
self.vllm_config,
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
self.kv_connector = get_kv_connector(
self.vllm_config,
kv_caches_dict,
packed_backing=packed_backing,
packed_block_stride=packed_block_stride,
)
def _init_kv_zero_meta(self) -> None:
"""Build KV-block zeroing metadata; invoked from gpu_worker."""
+56 -21
View File
@@ -6995,7 +6995,7 @@ class GPUModelRunner(
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
) -> tuple[dict[str, torch.Tensor], torch.Tensor | None]:
"""
Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models.
@@ -7007,12 +7007,20 @@ class GPUModelRunner(
corresponding memory buffer for KV cache.
"""
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
packed_backing: torch.Tensor | None = None
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=self.device
)
if kv_cache_tensor.block_stride > 0:
if packed_backing is None:
packed_backing = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=self.device
)
backing = packed_backing
else:
backing = torch.zeros(
kv_cache_tensor.size, dtype=torch.int8, device=self.device
)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
kv_cache_raw_tensors[layer_name] = backing
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
@@ -7023,7 +7031,7 @@ class GPUModelRunner(
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
return kv_cache_raw_tensors, packed_backing
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups)
@@ -7052,6 +7060,12 @@ class GPUModelRunner(
"""
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
layer_packing: dict[str, tuple[int, int]] = {}
for kv_tensor in self.kv_cache_config.kv_cache_tensors:
if kv_tensor.block_stride > 0:
for ln in kv_tensor.shared_by:
layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride)
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
@@ -7063,8 +7077,13 @@ class GPUModelRunner(
if layer_name in self.runner_only_attn_layers:
continue
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
packing = layer_packing.get(layer_name)
if packing is not None:
_, blk_stride = packing
num_blocks = raw_tensor.numel() // blk_stride
else:
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
num_blocks_per_kv_block = (
@@ -7106,15 +7125,19 @@ class GPUModelRunner(
]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
if kv_cache_spec.page_size_padded is not None:
# Use strided view to handle page_size_bytes that
# include padding. This follows
# the same pattern as MambaSpec handling below.
# NOTE: This assumes kv_cache_shape[0] == num_blocks
# (i.e. the first physical dimension is the block
# index), which holds for MLA backends but NOT for
# standard attention backends whose shape starts with
# a K/V dimension of size 2.
if packing is not None:
layer_offset, blk_stride = packing
dtype_size = get_dtype_size(dtype)
page_stride = blk_stride // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
strides[inv_order[0]] = page_stride
kv_cache = torch.as_strided(
raw_tensor,
size=kv_cache_shape,
stride=tuple(strides),
storage_offset=layer_offset // dtype_size,
)
elif kv_cache_spec.page_size_padded is not None:
dtype_size = get_dtype_size(dtype)
page_stride = kv_cache_spec.page_size_bytes // dtype_size
strides = list(torch.empty(kv_cache_shape).stride())
@@ -7125,7 +7148,6 @@ class GPUModelRunner(
stride=tuple(strides),
)
else:
# No padding — safe to use a contiguous view.
kv_cache = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = kv_cache.permute(*inv_order)
@@ -7227,13 +7249,24 @@ class GPUModelRunner(
else:
# Fallback to the general case
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
kv_cache_raw_tensors, packed_backing = self._allocate_kv_cache_tensors(
kv_cache_config
)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(
kv_cache_raw_tensors, kernel_block_sizes
)
if packed_backing is not None:
block_stride = next(
t.block_stride
for t in kv_cache_config.kv_cache_tensors
if t.block_stride > 0
)
self.cross_layers_kv_cache = packed_backing
self._packed_block_stride = block_stride
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
@@ -7329,9 +7362,11 @@ class GPUModelRunner(
if has_kv_transfer_group() and not is_profiling:
kv_transfer_group = get_kv_transfer_group()
if self.cross_layers_kv_cache is not None:
assert self.cross_layers_attn_backend is not None
block_stride = getattr(self, "_packed_block_stride", None)
kv_transfer_group.register_cross_layers_kv_cache(
self.cross_layers_kv_cache, self.cross_layers_attn_backend
self.cross_layers_kv_cache,
self.cross_layers_attn_backend,
block_stride=block_stride,
)
else:
kv_transfer_group.register_kv_caches(kv_caches)