From 823d271c0dc7dce43f6a9716be3c99aaf82ced37 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 3 Jun 2026 19:03:09 +0800 Subject: [PATCH] [Attention][CPU] Standardize kv layout to blocks first (#44393) Signed-off-by: jiang1.li --- tests/kernels/attention/test_cpu_attn.py | 9 ++++++--- vllm/v1/attention/backends/cpu_attn.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/kernels/attention/test_cpu_attn.py b/tests/kernels/attention/test_cpu_attn.py index 6af1bfe1e7a..c3939502551 100644 --- a/tests/kernels/attention/test_cpu_attn.py +++ b/tests/kernels/attention/test_cpu_attn.py @@ -258,10 +258,13 @@ def varlen_with_paged_kv( # KV cache for CPU attention cache_dtype = torch.uint8 if is_fp8 else dtype - packed_key_cache = torch.empty( - num_blocks, num_kv_heads, block_size, head_size, dtype=cache_dtype + packed_key_value_cache = torch.empty( + num_blocks, num_kv_heads, block_size, head_size * 2, dtype=cache_dtype ) - packed_value_cache = torch.empty_like(packed_key_cache) + packed_key_value_cache = packed_key_value_cache.view( + (num_blocks, num_kv_heads, block_size * 2, -1) + ) + packed_key_cache, packed_value_cache = packed_key_value_cache.chunk(2, dim=2) cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( dim=0, dtype=torch.int32 diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 005975c4775..3519691a3c5 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -93,7 +93,7 @@ class CPUAttentionBackend(AttentionBackend): head_size: int, cache_dtype_str: str = "auto", ) -> tuple[int, ...]: - return 2, num_blocks, num_kv_heads, block_size, head_size + return num_blocks, num_kv_heads, block_size, 2 * head_size @classmethod def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": @@ -308,7 +308,7 @@ class CPUAttentionBackendImpl(AttentionImpl): key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [2, num_blocks, num_kv_heads, block_size, head_size] + [num_blocks, num_kv_heads, block_size, 2 * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -338,8 +338,12 @@ class CPUAttentionBackendImpl(AttentionImpl): ) # For decoder and cross-attention, use KV cache, size are - # [num_blocks, num_kv_heads, block_size, head_size] - key_cache, value_cache = kv_cache.unbind(0) + # [num_blocks, num_kv_heads, block_size, 2 * head_size] + # Make a view [num_blocks, num_kv_heads, block_size * 2, head_size] + # Then slice KV at dim 2 + num_blocks, num_kv_heads, block_size, _ = kv_cache.size() + kv_cache = kv_cache.view((num_blocks, num_kv_heads, block_size * 2, -1)) + key_cache, value_cache = kv_cache.chunk(2, dim=2) # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached