[Attention][AMD] Standardize kv layout to blocks first for AMD (#43660)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2026-05-28 19:28:50 +02:00
committed by GitHub
parent 53a2088675
commit 5b115bb8a3
6 changed files with 57 additions and 44 deletions
+7 -3
View File
@@ -362,6 +362,7 @@ def flash_attn_triton_available() -> bool:
def _get_backend_priorities(
use_mla: bool,
use_sparse: bool,
use_kv_connector: bool = False,
) -> list[AttentionBackendEnum]:
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
@@ -380,9 +381,11 @@ def _get_backend_priorities(
AttentionBackendEnum.TRITON_MLA,
]
backends = [
AttentionBackendEnum.ROCM_ATTN,
]
backends = []
# ROCM_ATTN uses (2, num_blocks, ...) KV cache layout which is
# incompatible with KV connectors that require blocks-first layout.
if not use_kv_connector:
backends.append(AttentionBackendEnum.ROCM_ATTN)
if rocm_aiter_ops.is_mha_enabled():
backends.append(AttentionBackendEnum.ROCM_AITER_FA)
if is_aiter_found_and_supported():
@@ -461,6 +464,7 @@ class RocmPlatform(Platform):
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla,
attn_selector_config.use_sparse,
attn_selector_config.use_kv_connector,
)
for priority, backend in enumerate(backend_priorities):
try:
+7
View File
@@ -240,6 +240,10 @@ class AttentionBackend(ABC):
def supports_batch_invariance(cls) -> bool:
return False
@classmethod
def supports_kv_connector(cls) -> bool:
# Capability flag consulted during backend validation: when a KV connector
# is in use and this returns False, the backend is reported as invalid with
# the reason "KV connector not supported".
# The base default is True; a backend whose KV cache layout is incompatible
# with KV connectors overrides this to return False.
return True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
@@ -283,6 +287,7 @@ class AttentionBackend(ABC):
attn_type: str,
use_non_causal: bool = False,
use_batch_invariant: bool = False,
use_kv_connector: bool = False,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
@@ -319,6 +324,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-causal attention not supported")
if use_batch_invariant and not cls.supports_batch_invariance():
invalid_reasons.append("batch invariance not supported")
if use_kv_connector and not cls.supports_kv_connector():
invalid_reasons.append("KV connector not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
+23 -35
View File
@@ -225,6 +225,8 @@ if current_platform.is_rocm():
x,
k_stride0,
v_stride0,
k_cache_block_stride,
v_cache_block_stride,
block_size,
head_size,
num_kv_heads,
@@ -241,15 +243,17 @@ if current_platform.is_rocm():
return
block_id = slot_id // block_size
block_offset = slot_id % block_size
dst_offset = (
block_id * num_kv_heads * head_size * block_size
+ head_id * head_size * block_size
k_dst_offset = (
block_id * k_cache_block_stride + head_id * head_size * block_size
)
dst_k_shuffle_offset = (
dst_offset + offset // x * block_size * x + block_offset * x + offset % x
k_dst_offset + offset // x * block_size * x + block_offset * x + offset % x
)
v_dst_offset = (
block_id * v_cache_block_stride + head_id * head_size * block_size
)
dst_v_shuffle_offset = (
dst_offset
v_dst_offset
+ block_offset // x * head_size * x
+ offset * x
+ block_offset % x
@@ -280,18 +284,6 @@ if current_platform.is_rocm():
_, num_kv_heads, head_size = key.shape
num_blocks, block_size, _, _ = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
dtype=key_cache.dtype,
device="meta",
)
v_cache_template = torch.empty(
[num_blocks, num_kv_heads, block_size // x, head_size, x],
dtype=value_cache.dtype,
device="meta",
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
QUANT = False
if is_quantized_kv_cache(kv_cache_dtype):
QUANT = True
@@ -302,14 +294,16 @@ if current_platform.is_rocm():
reshape_and_cache_shuffle_kernel[grid](
key,
value,
new_key_cache,
new_value_cache,
key_cache,
value_cache,
slot_mapping,
k_scales,
v_scales,
x,
key.stride(0),
value.stride(0),
key_cache.stride(0),
value_cache.stride(0),
block_size,
head_size,
num_kv_heads,
@@ -485,7 +479,7 @@ class AiterFlashAttentionMetadataBuilder(
kv_cache_shape = self.vllm_config.compilation_config.static_forward_context[
first_layer_name
].kv_cache.shape
num_blocks = kv_cache_shape[1]
num_blocks = kv_cache_shape[0]
self.scale = torch.ones(
[num_blocks, self.num_heads_kv, self.block_size],
dtype=torch.float32,
@@ -765,7 +759,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
@@ -1025,7 +1019,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
[num_blocks, 2, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -1053,7 +1047,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype())
@@ -1278,18 +1272,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
max_logits = torch.empty_like(exp_sums)
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
dtype=key_cache.dtype,
device="meta",
new_key_cache = key_cache.reshape(
num_blocks, num_kv_heads, head_size // x, block_size, x
)
v_cache_template = torch.empty(
[num_blocks, num_kv_heads, block_size // x, head_size, x],
dtype=value_cache.dtype,
device="meta",
new_value_cache = value_cache.reshape(
num_blocks, num_kv_heads, block_size // x, head_size, x
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
k_qscale = (
layer._k_scale
if attn_metadata.k_scale is None
@@ -1378,7 +1366,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
@@ -1446,7 +1434,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
@@ -74,7 +74,7 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
@@ -153,7 +153,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
[num_blocks, 2, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -194,7 +194,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
layer,
)
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)
softmax_scale = self.scale
if is_quantized_kv_cache(self.kv_cache_dtype):
@@ -243,7 +243,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache.
ops.reshape_and_cache_flash(
@@ -276,7 +276,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
+6
View File
@@ -208,6 +208,12 @@ class RocmAttentionBackend(AttentionBackend):
def supports_non_causal(cls) -> bool:
return True
@classmethod
def supports_kv_connector(cls) -> bool:
# Opts this backend out whenever a KV connector is in use: the base-class
# validation appends "KV connector not supported" for backends that
# return False here.
# ROCM_ATTN uses (2, num_blocks, ...) KV cache layout which is
# incompatible with KV connectors that require blocks-first layout.
return False
forward_includes_kv_cache_update: bool = False
@staticmethod
+9 -1
View File
@@ -31,6 +31,7 @@ class AttentionSelectorConfig(NamedTuple):
attn_type: str = AttentionType.DECODER
use_non_causal: bool = False
use_batch_invariant: bool = False
use_kv_connector: bool = False
def __repr__(self):
return (
@@ -45,7 +46,8 @@ class AttentionSelectorConfig(NamedTuple):
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type}, "
f"use_non_causal={self.use_non_causal}, "
f"use_batch_invariant={self.use_batch_invariant})"
f"use_batch_invariant={self.use_batch_invariant}, "
f"use_kv_connector={self.use_kv_connector})"
)
@@ -80,6 +82,11 @@ def get_attn_backend(
else:
block_size = None
kv_transfer_config = vllm_config.kv_transfer_config
use_kv_connector = (
kv_transfer_config is not None and kv_transfer_config.is_kv_transfer_instance
)
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
@@ -93,6 +100,7 @@ def get_attn_backend(
attn_type=attn_type or AttentionType.DECODER,
use_non_causal=vllm_config.attention_config.use_non_causal,
use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
use_kv_connector=use_kv_connector,
)
return _cached_get_attn_backend(