From 5919036d8cd4f7e7df1596ca2fbf61a6195a335f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 3 Jun 2026 14:02:34 -0400 Subject: [PATCH] Pack KV caches into contiguous per-block allocations for DeepSeek V4 For DeepSeek V4, pack all layer data contiguously per block so that NIXL registers one region per block instead of one per layer. Full-attention MLA + SWA/compressor caches share one contiguous allocation per block. Each layer gets an as_strided view with storage_offset into the packed backing tensor. The packed backing tensor and block_stride are passed through the cross-layer KV cache registration API so NIXL registers one region with block_len=block_stride instead of many separate regions. Previously: 92 NIXL regions, 92 tiny P2P transfers per block (~16KB). Now: 1 region, 1 large RDMA transfer per block (~1.48MB). Co-authored-by: Claude Signed-off-by: Tyler Michael Smith --- tests/v1/core/test_contiguous_kv_packing.py | 214 ++++++++++++++++++ .../kv_transfer/kv_connector/v1/base.py | 5 +- .../v1/mooncake/store/connector.py | 5 +- .../kv_connector/v1/multi_connector.py | 9 +- .../kv_connector/v1/nixl/connector.py | 9 +- .../kv_connector/v1/nixl/worker.py | 43 ++-- .../kv_connector/v1/offloading/worker.py | 6 +- .../kv_connector/v1/offloading_connector.py | 5 +- vllm/v1/core/kv_cache_utils.py | 38 +++- vllm/v1/kv_cache_interface.py | 2 + vllm/v1/worker/gpu/attn_utils.py | 69 ++++-- vllm/v1/worker/gpu/kv_connector.py | 29 ++- vllm/v1/worker/gpu/model_runner.py | 9 +- vllm/v1/worker/gpu_model_runner.py | 77 +++++-- 14 files changed, 441 insertions(+), 79 deletions(-) create mode 100644 tests/v1/core/test_contiguous_kv_packing.py diff --git a/tests/v1/core/test_contiguous_kv_packing.py b/tests/v1/core/test_contiguous_kv_packing.py new file mode 100644 index 00000000000..56fd7083ff5 --- /dev/null +++ b/tests/v1/core/test_contiguous_kv_packing.py @@ -0,0 +1,214 @@ +"""Tests for contiguous KV block packing in _get_kv_cache_config_deepseek_v4.""" + +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.v1.core.kv_cache_utils import _get_kv_cache_config_deepseek_v4 +from vllm.v1.kv_cache_interface import ( + KVCacheGroupSpec, + MLAAttentionSpec, + UniformTypeKVCacheSpecs, +) + + +def _make_mla_spec(page_size: int, block_size: int = 256) -> MLAAttentionSpec: + return MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + page_size_padded=page_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + alignment=576, + ) + + +def _make_groups(n_c4, n_c128, n_swa): + PS_C4_MLA = 37440 + PS_C4_IDX = 8640 + PS_C128 = 1728 + PS_SWA = 37440 + + mla_specs = {} + for i in range(n_c4): + mla_specs[f"c4_mla.{i}"] = _make_mla_spec(PS_C4_MLA) + mla_specs[f"c4_idx.{i}"] = _make_mla_spec(PS_C4_IDX) + for i in range(n_c128): + mla_specs[f"c128_mla.{i}"] = _make_mla_spec(PS_C128) + + mla_group = KVCacheGroupSpec( + layer_names=list(mla_specs.keys()), + kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=mla_specs), + ) + + swa_specs = {} + for i in range(n_swa): + swa_specs[f"swa.{i}"] = _make_mla_spec(PS_SWA) + + swa_group = KVCacheGroupSpec( + layer_names=list(swa_specs.keys()), + kv_cache_spec=UniformTypeKVCacheSpecs(block_size=256, kv_cache_specs=swa_specs), + ) + + return [mla_group, swa_group] + + +def _mock_vllm_config(): + config = MagicMock() + config.scheduler_config.num_gpu_blocks_override = None + return config + + +def _run(n_c4=3, n_c128=2, n_swa=5, mem=100 * 1024 * 1024): + groups = _make_groups(n_c4, n_c128, n_swa) + return _get_kv_cache_config_deepseek_v4(_mock_vllm_config(), groups, mem) + + +def _split_tensors(tensors): + mla = [t for t in tensors if any("mla" in n or "idx" in n for n in t.shared_by)] + swa = [t for t in tensors if any("swa" in n for n in t.shared_by)] + return mla, swa + + +class TestMlaContiguousPacking: + def test_all_mla_share_one_size(self): + _, tensors = _run() + mla, _ = _split_tensors(tensors) + assert len(set(t.size for t in mla)) == 1 + + def test_all_mla_share_one_block_stride(self): + _, tensors = _run() + mla, _ = _split_tensors(tensors) + strides = set(t.block_stride for t in mla) + assert len(strides) == 1 + assert strides.pop() > 0 + + def test_mla_offsets_are_unique(self): + _, tensors = _run(n_c4=5, n_c128=4) + mla, _ = _split_tensors(tensors) + offsets = [t.offset for t in mla] + assert len(offsets) == len(set(offsets)) + + def test_mla_offsets_fit_within_block_stride(self): + _, tensors = _run(n_c4=5, n_c128=4) + mla, _ = _split_tensors(tensors) + stride = mla[0].block_stride + for t in mla: + assert 0 <= t.offset < stride, f"offset {t.offset} >= block_stride {stride}" + + def test_mla_all_layers_accounted_for(self): + n_c4, n_c128 = 5, 4 + _, tensors = _run(n_c4=n_c4, n_c128=n_c128) + mla, _ = _split_tensors(tensors) + all_names = set() + for t in mla: + all_names.update(t.shared_by) + expected_count = n_c4 * 2 + n_c128 # c4_mla + c4_idx + c128_mla + assert len(all_names) == expected_count + + def test_size_equals_stride_times_num_blocks(self): + num_blocks, tensors = _run() + mla, _ = _split_tensors(tensors) + for t in mla: + assert t.size == t.block_stride * num_blocks + + +class TestSwaContiguousPacking: + def test_swa_packed_separately_from_mla(self): + _, tensors = _run() + mla, swa = _split_tensors(tensors) + assert len(swa) > 0 + assert len(mla) > 0 + assert set(t.size for t in swa) != set(t.size for t in mla) + + def test_swa_share_one_size(self): + _, tensors = _run() + _, swa = _split_tensors(tensors) + assert len(set(t.size for t in swa)) == 1 + + def test_swa_share_one_block_stride(self): + _, tensors = _run() + _, swa = _split_tensors(tensors) + strides = set(t.block_stride for t in swa) + assert len(strides) == 1 + assert strides.pop() > 0 + + def test_swa_offsets_are_unique(self): + _, tensors = _run(n_swa=10) + _, swa = _split_tensors(tensors) + offsets = [t.offset for t in swa] + assert len(offsets) == len(set(offsets)) + + def test_swa_all_layers_accounted_for(self): + n_swa = 7 + _, tensors = _run(n_swa=n_swa) + _, swa = _split_tensors(tensors) + all_names = set() + for t in swa: + all_names.update(t.shared_by) + assert len(all_names) == n_swa + + def test_swa_size_equals_stride_times_num_blocks(self): + num_blocks, tensors = _run() + _, swa = _split_tensors(tensors) + for t in swa: + assert t.size == t.block_stride * num_blocks + + +class TestStridedViewCorrectness: + def test_views_are_independent(self): + page_sizes = [1728, 8640, 37440] + layer_tuple_bytes = sum(page_sizes) + num_tuples = 3 + block_stride = layer_tuple_bytes * num_tuples + num_blocks = 4 + + backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8) + + views = [] + for t in range(num_tuples): + for ps_idx, ps in enumerate(page_sizes): + offset = t * layer_tuple_bytes + sum(page_sizes[:ps_idx]) + view = torch.as_strided( + backing, + size=(num_blocks, ps), + stride=(block_stride, 1), + storage_offset=offset, + ) + views.append((f"tuple{t}_ps{ps}", view)) + + for i, (name, view) in enumerate(views): + view.fill_(i + 1) + for j, (_, other) in enumerate(views): + if j <= i: + continue + assert other.sum() == 0, f"Writing to {name} corrupted view {j}" + + for i, (name, view) in enumerate(views): + assert view.sum() == (i + 1) * view.numel() + + def test_block_isolation(self): + ps = 37440 + n_layers = 5 + block_stride = ps * n_layers + num_blocks = 3 + + backing = torch.zeros(block_stride * num_blocks, dtype=torch.uint8) + views = [ + torch.as_strided(backing, (num_blocks, ps), (block_stride, 1), layer * ps) + for layer in range(n_layers) + ] + + views[2][1].fill_(42) + assert views[2][0].sum() == 0 + assert views[2][2].sum() == 0 + for i in range(n_layers): + if i != 2: + assert views[i].sum() == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 71d89f43a79..9afffe7c704 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -259,7 +259,10 @@ class KVConnectorBase_V1(ABC): return def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] + self, + kv_cache: torch.Tensor, + attn_backend: type["AttentionBackend"] | None, + block_stride: int | None = None, ): """ Initialize with a single KV cache tensor used by all layers. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py index 14d4b381a3c..341db0cc684 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py @@ -248,7 +248,10 @@ class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA): self.connector_worker.register_kv_caches(kv_caches) def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type + self, + kv_cache: torch.Tensor, + attn_backend: type | None, + block_stride: int | None = None, ): assert self.connector_worker is not None assert ( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 46354337e65..ac4cda3ee0d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -236,11 +236,16 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA): return ret def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + self, + kv_cache: torch.Tensor, + attn_backend: type[AttentionBackend] | None, + block_stride: int | None = None, ): # Register on all connectors for c in self._connectors: - c.register_cross_layers_kv_cache(kv_cache, attn_backend) + c.register_cross_layers_kv_cache( + kv_cache, attn_backend, block_stride=block_stride + ) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for c in self._connectors: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py index dad81e84c45..b51446c0c90 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py @@ -210,10 +210,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA): self.connector_worker.register_kv_caches(kv_caches) def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + self, + kv_cache: torch.Tensor, + attn_backend: type[AttentionBackend] | None, + block_stride: int | None = None, ): assert self.connector_worker is not None - self.connector_worker.register_cross_layers_kv_caches(kv_cache) + self.connector_worker.register_cross_layers_kv_caches( + kv_cache, block_stride=block_stride + ) def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index a297058c845..a83c8ff01f8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -774,15 +774,18 @@ class NixlConnectorWorker: fut.add_done_callback(request_ready) - def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None: + def register_cross_layers_kv_caches( + self, + kv_cache: torch.Tensor, + block_stride: int | None = None, + ) -> None: """Register a cross-layers KV cache tensor with NIXL. - `use_uniform_kv_cache()` guarantees a single KV cache group whose - layers all share the same `AttentionSpec`, so any layer name from - `_layer_specs` yields the correct per-layer spec for `page_size_bytes`. + When block_stride is provided (packed heterogeneous layout), + it overrides the per-layer page_size calculation. """ + self._packed_block_stride = block_stride first_layer = next(iter(self._layer_specs)) - # Forwarding a real layer name rather than a synthetic key self.register_kv_caches({first_layer: kv_cache}) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): @@ -859,19 +862,23 @@ class NixlConnectorWorker: ) # `layer_spec.page_size_bytes` only accounts for logical page_size, that is # the page_size assuming constant `self._logical_num_blocks`. - physical_page_size = ( - layer_spec.page_size_bytes - if isinstance(layer_spec, MambaSpec) - else layer_spec.page_size_bytes - // self._physical_blocks_per_logical_kv_block - ) - # For when registering multiple tensors eg K/V in separate regions. - physical_page_size = physical_page_size // len(cache_list) - if self.transfer_topo._cross_layers_blocks: - # When cross-layers blocks are used, multiply by number of layers - physical_page_size = physical_page_size * len( - self.kv_cache_config.kv_cache_tensors + packed_stride = getattr(self, "_packed_block_stride", None) + if packed_stride is not None: + physical_page_size = packed_stride + else: + physical_page_size = ( + layer_spec.page_size_bytes + if isinstance(layer_spec, MambaSpec) + else layer_spec.page_size_bytes + // self._physical_blocks_per_logical_kv_block ) + # For when registering multiple tensors eg K/V in separate regions. + physical_page_size = physical_page_size // len(cache_list) + if self.transfer_topo._cross_layers_blocks: + # When cross-layers blocks are used, multiply by number of layers + physical_page_size = physical_page_size * len( + self.kv_cache_config.kv_cache_tensors + ) num_blocks = ( self._logical_num_blocks if isinstance(layer_spec, MambaSpec) @@ -906,7 +913,7 @@ class NixlConnectorWorker: else: self.block_len_per_layer.append(physical_page_size) - if cache.shape[0] != num_blocks: + if packed_stride is None and cache.shape[0] != num_blocks: raise AssertionError( "All kv cache tensors must have the same number of " f"blocks; layer={layer_name}, " diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py index 8957ce3445a..6342c8c5bec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py @@ -183,8 +183,12 @@ class OffloadingConnectorWorker: self._register_handlers(canonical_kv_caches) def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + self, + kv_cache: torch.Tensor, + attn_backend: type[AttentionBackend] | None, + block_stride: int | None = None, ): + assert attn_backend is not None # verify that num_blocks is at physical position 0 in the cross-layers # tensor layout. test_shape = attn_backend.get_kv_cache_shape( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 20888c71f84..0b549e67394 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -76,7 +76,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA): self.connector_worker.register_kv_caches(kv_caches) def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + self, + kv_cache: torch.Tensor, + attn_backend: type[AttentionBackend] | None, + block_stride: int | None = None, ): assert self.connector_worker is not None self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ae3db581c0d..c3fc676ea9d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1240,17 +1240,43 @@ def _get_kv_cache_config_deepseek_v4( num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples) num_blocks = may_override_num_blocks(vllm_config, num_blocks) + full_block_bytes = layer_tuple_page_bytes * num_layer_tuples + total_size = full_block_bytes * num_blocks + kv_cache_tensors: list[KVCacheTensor] = [] + for tuple_idx in range(num_layer_tuples): + ps_offset = 0 for ps in page_sizes: - shared_by: list[str] = [] - for b in bucketed: + mla_shared: list[str] = [] + swa_shared: list[str] = [] + for g_idx, b in enumerate(bucketed): bucket = b.get(ps) if bucket is not None and tuple_idx < len(bucket): - shared_by.append(bucket[tuple_idx]) - kv_cache_tensors.append( - KVCacheTensor(size=ps * num_blocks, shared_by=shared_by) - ) + if g_idx == 0: + mla_shared.append(bucket[tuple_idx]) + else: + swa_shared.append(bucket[tuple_idx]) + offset = tuple_idx * layer_tuple_page_bytes + ps_offset + if mla_shared: + kv_cache_tensors.append( + KVCacheTensor( + size=total_size, + shared_by=mla_shared, + offset=offset, + block_stride=full_block_bytes, + ) + ) + if swa_shared: + kv_cache_tensors.append( + KVCacheTensor( + size=total_size, + shared_by=swa_shared, + offset=offset, + block_stride=full_block_bytes, + ) + ) + ps_offset += ps return num_blocks, kv_cache_tensors diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 2f8048c7966..a1525b17d33 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -834,6 +834,8 @@ class KVCacheTensor: size: int # size of the KV cache tensor in bytes shared_by: list[str] # layer names that share the same KV cache tensor + offset: int = 0 # byte offset of this layer within a contiguous block + block_stride: int = 0 # total bytes per block in a packed layout (0 = not packed) @dataclass diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 6fc55ee3203..c4b6b213022 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -151,8 +151,16 @@ def _allocate_kv_cache( kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device ): kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + packed_backing: torch.Tensor | None = None for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) + if kv_cache_tensor.block_stride > 0: + if packed_backing is None: + packed_backing = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=device + ) + tensor = packed_backing + else: + tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -163,7 +171,7 @@ def _allocate_kv_cache( assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), ( "Some layers are not correctly initialized" ) - return kv_cache_raw_tensors + return kv_cache_raw_tensors, packed_backing def _reshape_kv_cache( @@ -172,10 +180,18 @@ def _reshape_kv_cache( cache_dtype: str, kernel_block_sizes: list[int], shared_kv_cache_layers: dict[str, str], + kv_cache_config: "KVCacheConfig | None" = None, ) -> dict[str, Any]: kv_caches: dict[str, Any] = {} has_attn, has_mamba = False, False + layer_packing: dict[str, tuple[int, int]] = {} + if kv_cache_config is not None: + for kv_tensor in kv_cache_config.kv_cache_tensors: + if kv_tensor.block_stride > 0: + for ln in kv_tensor.shared_by: + layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride) + for group in attn_groups: if group.kv_cache_group_id >= len(kernel_block_sizes): continue @@ -194,8 +210,13 @@ def _reshape_kv_cache( continue kv_raw_tensor = kv_cache_raw_tensors[layer_name] - assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes + packing = layer_packing.get(layer_name) + if packing is not None: + _, blk_stride = packing + num_blocks = kv_raw_tensor.numel() // blk_stride + else: + assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True @@ -229,14 +250,19 @@ def _reshape_kv_cache( dtype = kv_cache_spec.dtype kv_tensor = kv_raw_tensor.view(dtype) - if kv_cache_spec.page_size_padded is not None: - # Use strided view to handle page_size_bytes that - # include padding. This follows the same pattern as - # MambaSpec handling in gpu_model_runner.py. - # NOTE: This assumes kv_cache_shape[0] == num_blocks - # (i.e. the first physical dimension is the block - # index), which holds for all current backends - # (MLA, FlashAttention, TritonAttention, etc.). + if packing is not None: + layer_offset, blk_stride = packing + dtype_size = get_dtype_size(dtype) + page_stride = blk_stride // dtype_size + strides = list(torch.empty(kv_cache_shape).stride()) + strides[inv_order[0]] = page_stride + kv_cache = torch.as_strided( + kv_tensor, + size=kv_cache_shape, + stride=tuple(strides), + storage_offset=layer_offset // dtype_size, + ) + elif kv_cache_spec.page_size_padded is not None: dtype_size = get_dtype_size(dtype) page_stride = kv_cache_spec.page_size_bytes // dtype_size strides = list(torch.empty(kv_cache_shape).stride()) @@ -247,7 +273,6 @@ def _reshape_kv_cache( stride=tuple(strides), ) else: - # No padding — safe to use a contiguous view. kv_cache = kv_tensor.view(kv_cache_shape) kv_caches[layer_name] = kv_cache.permute(*inv_order) @@ -349,9 +374,9 @@ def init_kv_cache( cache_dtype: str, kernel_block_sizes: list[int], vllm_config: VllmConfig, -) -> dict[str, Any]: +) -> tuple[dict[str, Any], torch.Tensor | None, int | None]: shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config) - kv_cache_raw_tensors = _allocate_kv_cache( + kv_cache_raw_tensors, packed_backing = _allocate_kv_cache( kv_cache_config, shared_kv_cache_layers, device ) flattened_attn_groups = list(group for groups in attn_groups for group in groups) @@ -361,9 +386,21 @@ def init_kv_cache( kernel_block_sizes=kernel_block_sizes, cache_dtype=cache_dtype, shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_config=kv_cache_config, ) bind_kv_cache(kv_caches, forward_context, runner_kv_caches) - return kv_caches + + packed_block_stride = None + if packed_backing is not None: + packed_block_stride = next( + ( + t.block_stride + for t in kv_cache_config.kv_cache_tensors + if t.block_stride > 0 + ), + None, + ) + return kv_caches, packed_backing, packed_block_stride def build_slot_mappings_by_layer( diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py index cdacb36e583..aa40119133e 100644 --- a/vllm/v1/worker/gpu/kv_connector.py +++ b/vllm/v1/worker/gpu/kv_connector.py @@ -46,14 +46,20 @@ class KVConnector: class ActiveKVConnector(KVConnector): def __init__( - self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor] + self, + vllm_config: VllmConfig, + kv_caches_dict: dict[str, torch.Tensor], + packed_backing: torch.Tensor | None = None, + packed_block_stride: int | None = None, ): self.vllm_config = vllm_config self.kv_connector = get_kv_transfer_group() - # Register kv caches with KV Connector if applicable. - # TODO: support cross_layers_kv_cache - # (see https://github.com/vllm-project/vllm/pull/27743) - self.kv_connector.register_kv_caches(kv_caches_dict) + if packed_backing is not None and packed_block_stride is not None: + self.kv_connector.register_cross_layers_kv_cache( + packed_backing, None, block_stride=packed_block_stride + ) + else: + self.kv_connector.register_kv_caches(kv_caches_dict) self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks) self._disabled = False @@ -114,10 +120,17 @@ NO_OP_KV_CONNECTOR = KVConnector() def get_kv_connector( - vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor] + vllm_config: VllmConfig, + kv_caches_dict: dict[str, torch.Tensor], + packed_backing: torch.Tensor | None = None, + packed_block_stride: int | None = None, ) -> KVConnector: if not has_kv_transfer_group(): - # No-op connector. return NO_OP_KV_CONNECTOR - return ActiveKVConnector(vllm_config, kv_caches_dict) + return ActiveKVConnector( + vllm_config, + kv_caches_dict, + packed_backing=packed_backing, + packed_block_stride=packed_block_stride, + ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 367147b0b4d..07b53ccd3ec 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -465,7 +465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) self.kv_caches: list[torch.Tensor] = [] - kv_caches_dict = init_kv_cache( + kv_caches_dict, packed_backing, packed_block_stride = init_kv_cache( self.kv_caches, self.compilation_config.static_forward_context, self.kv_cache_config, @@ -475,7 +475,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.kernel_block_sizes, self.vllm_config, ) - self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) + self.kv_connector = get_kv_connector( + self.vllm_config, + kv_caches_dict, + packed_backing=packed_backing, + packed_block_stride=packed_block_stride, + ) def _init_kv_zero_meta(self) -> None: """Build KV-block zeroing metadata; invoked from gpu_worker.""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 462375644ba..ac8b8d14ff2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6995,7 +6995,7 @@ class GPUModelRunner( def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig - ) -> dict[str, torch.Tensor]: + ) -> tuple[dict[str, torch.Tensor], torch.Tensor | None]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -7007,12 +7007,20 @@ class GPUModelRunner( corresponding memory buffer for KV cache. """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + packed_backing: torch.Tensor | None = None for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros( - kv_cache_tensor.size, dtype=torch.int8, device=self.device - ) + if kv_cache_tensor.block_stride > 0: + if packed_backing is None: + packed_backing = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) + backing = packed_backing + else: + backing = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: - kv_cache_raw_tensors[layer_name] = tensor + kv_cache_raw_tensors[layer_name] = backing layer_names = set() for group in kv_cache_config.kv_cache_groups: @@ -7023,7 +7031,7 @@ class GPUModelRunner( assert layer_names == set(kv_cache_raw_tensors.keys()), ( "Some layers are not correctly initialized" ) - return kv_cache_raw_tensors + return kv_cache_raw_tensors, packed_backing def _attn_group_iterator(self) -> Iterator[AttentionGroup]: return itertools.chain.from_iterable(self.attn_groups) @@ -7052,6 +7060,12 @@ class GPUModelRunner( """ kv_caches: dict[str, torch.Tensor] = {} has_attn, has_mamba = False, False + + layer_packing: dict[str, tuple[int, int]] = {} + for kv_tensor in self.kv_cache_config.kv_cache_tensors: + if kv_tensor.block_stride > 0: + for ln in kv_tensor.shared_by: + layer_packing[ln] = (kv_tensor.offset, kv_tensor.block_stride) for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend @@ -7063,8 +7077,13 @@ class GPUModelRunner( if layer_name in self.runner_only_attn_layers: continue raw_tensor = kv_cache_raw_tensors[layer_name] - assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes + packing = layer_packing.get(layer_name) + if packing is not None: + _, blk_stride = packing + num_blocks = raw_tensor.numel() // blk_stride + else: + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True num_blocks_per_kv_block = ( @@ -7106,15 +7125,19 @@ class GPUModelRunner( ] raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype) - if kv_cache_spec.page_size_padded is not None: - # Use strided view to handle page_size_bytes that - # include padding. This follows - # the same pattern as MambaSpec handling below. - # NOTE: This assumes kv_cache_shape[0] == num_blocks - # (i.e. the first physical dimension is the block - # index), which holds for MLA backends but NOT for - # standard attention backends whose shape starts with - # a K/V dimension of size 2. + if packing is not None: + layer_offset, blk_stride = packing + dtype_size = get_dtype_size(dtype) + page_stride = blk_stride // dtype_size + strides = list(torch.empty(kv_cache_shape).stride()) + strides[inv_order[0]] = page_stride + kv_cache = torch.as_strided( + raw_tensor, + size=kv_cache_shape, + stride=tuple(strides), + storage_offset=layer_offset // dtype_size, + ) + elif kv_cache_spec.page_size_padded is not None: dtype_size = get_dtype_size(dtype) page_stride = kv_cache_spec.page_size_bytes // dtype_size strides = list(torch.empty(kv_cache_shape).stride()) @@ -7125,7 +7148,6 @@ class GPUModelRunner( stride=tuple(strides), ) else: - # No padding — safe to use a contiguous view. kv_cache = raw_tensor.view(kv_cache_shape) kv_caches[layer_name] = kv_cache.permute(*inv_order) @@ -7227,13 +7249,24 @@ class GPUModelRunner( else: # Fallback to the general case # Initialize the memory buffer for KV cache - kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + kv_cache_raw_tensors, packed_backing = self._allocate_kv_cache_tensors( + kv_cache_config + ) # Change the memory buffer to the desired shape kv_caches = self._reshape_kv_cache_tensors( kv_cache_raw_tensors, kernel_block_sizes ) + if packed_backing is not None: + block_stride = next( + t.block_stride + for t in kv_cache_config.kv_cache_tensors + if t.block_stride > 0 + ) + self.cross_layers_kv_cache = packed_backing + self._packed_block_stride = block_stride + # Set up cross-layer KV cache sharing for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) @@ -7329,9 +7362,11 @@ class GPUModelRunner( if has_kv_transfer_group() and not is_profiling: kv_transfer_group = get_kv_transfer_group() if self.cross_layers_kv_cache is not None: - assert self.cross_layers_attn_backend is not None + block_stride = getattr(self, "_packed_block_stride", None) kv_transfer_group.register_cross_layers_kv_cache( - self.cross_layers_kv_cache, self.cross_layers_attn_backend + self.cross_layers_kv_cache, + self.cross_layers_attn_backend, + block_stride=block_stride, ) else: kv_transfer_group.register_kv_caches(kv_caches)