mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Attention][AMD] Standardize kv layout to blocks first for AMD (#43660)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -362,6 +362,7 @@ def flash_attn_triton_available() -> bool:
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
use_sparse: bool,
|
||||
use_kv_connector: bool = False,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
|
||||
@@ -380,9 +381,11 @@ def _get_backend_priorities(
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
]
|
||||
|
||||
backends = [
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
]
|
||||
backends = []
|
||||
# ROCM_ATTN uses (2, num_blocks, ...) KV cache layout which is
|
||||
# incompatible with KV connectors that require blocks-first layout.
|
||||
if not use_kv_connector:
|
||||
backends.append(AttentionBackendEnum.ROCM_ATTN)
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
backends.append(AttentionBackendEnum.ROCM_AITER_FA)
|
||||
if is_aiter_found_and_supported():
|
||||
@@ -461,6 +464,7 @@ class RocmPlatform(Platform):
|
||||
backend_priorities = _get_backend_priorities(
|
||||
attn_selector_config.use_mla,
|
||||
attn_selector_config.use_sparse,
|
||||
attn_selector_config.use_kv_connector,
|
||||
)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
|
||||
@@ -240,6 +240,10 @@ class AttentionBackend(ABC):
|
||||
def supports_batch_invariance(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_kv_connector(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""Check if backend supports a given attention type.
|
||||
@@ -283,6 +287,7 @@ class AttentionBackend(ABC):
|
||||
attn_type: str,
|
||||
use_non_causal: bool = False,
|
||||
use_batch_invariant: bool = False,
|
||||
use_kv_connector: bool = False,
|
||||
) -> list[str]:
|
||||
invalid_reasons = []
|
||||
if not cls.supports_head_size(head_size):
|
||||
@@ -319,6 +324,8 @@ class AttentionBackend(ABC):
|
||||
invalid_reasons.append("non-causal attention not supported")
|
||||
if use_batch_invariant and not cls.supports_batch_invariance():
|
||||
invalid_reasons.append("batch invariance not supported")
|
||||
if use_kv_connector and not cls.supports_kv_connector():
|
||||
invalid_reasons.append("KV connector not supported")
|
||||
combination_reason = cls.supports_combination(
|
||||
head_size,
|
||||
dtype,
|
||||
|
||||
@@ -225,6 +225,8 @@ if current_platform.is_rocm():
|
||||
x,
|
||||
k_stride0,
|
||||
v_stride0,
|
||||
k_cache_block_stride,
|
||||
v_cache_block_stride,
|
||||
block_size,
|
||||
head_size,
|
||||
num_kv_heads,
|
||||
@@ -241,15 +243,17 @@ if current_platform.is_rocm():
|
||||
return
|
||||
block_id = slot_id // block_size
|
||||
block_offset = slot_id % block_size
|
||||
dst_offset = (
|
||||
block_id * num_kv_heads * head_size * block_size
|
||||
+ head_id * head_size * block_size
|
||||
k_dst_offset = (
|
||||
block_id * k_cache_block_stride + head_id * head_size * block_size
|
||||
)
|
||||
dst_k_shuffle_offset = (
|
||||
dst_offset + offset // x * block_size * x + block_offset * x + offset % x
|
||||
k_dst_offset + offset // x * block_size * x + block_offset * x + offset % x
|
||||
)
|
||||
v_dst_offset = (
|
||||
block_id * v_cache_block_stride + head_id * head_size * block_size
|
||||
)
|
||||
dst_v_shuffle_offset = (
|
||||
dst_offset
|
||||
v_dst_offset
|
||||
+ block_offset // x * head_size * x
|
||||
+ offset * x
|
||||
+ block_offset % x
|
||||
@@ -280,18 +284,6 @@ if current_platform.is_rocm():
|
||||
_, num_kv_heads, head_size = key.shape
|
||||
num_blocks, block_size, _, _ = key_cache.shape
|
||||
x = 16 // key_cache.element_size()
|
||||
k_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, head_size // x, block_size, x],
|
||||
dtype=key_cache.dtype,
|
||||
device="meta",
|
||||
)
|
||||
v_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, block_size // x, head_size, x],
|
||||
dtype=value_cache.dtype,
|
||||
device="meta",
|
||||
)
|
||||
new_key_cache = key_cache.view_as(k_cache_template)
|
||||
new_value_cache = value_cache.view_as(v_cache_template)
|
||||
QUANT = False
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
QUANT = True
|
||||
@@ -302,14 +294,16 @@ if current_platform.is_rocm():
|
||||
reshape_and_cache_shuffle_kernel[grid](
|
||||
key,
|
||||
value,
|
||||
new_key_cache,
|
||||
new_value_cache,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
k_scales,
|
||||
v_scales,
|
||||
x,
|
||||
key.stride(0),
|
||||
value.stride(0),
|
||||
key_cache.stride(0),
|
||||
value_cache.stride(0),
|
||||
block_size,
|
||||
head_size,
|
||||
num_kv_heads,
|
||||
@@ -485,7 +479,7 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
kv_cache_shape = self.vllm_config.compilation_config.static_forward_context[
|
||||
first_layer_name
|
||||
].kv_cache.shape
|
||||
num_blocks = kv_cache_shape[1]
|
||||
num_blocks = kv_cache_shape[0]
|
||||
self.scale = torch.ones(
|
||||
[num_blocks, self.num_heads_kv, self.block_size],
|
||||
dtype=torch.float32,
|
||||
@@ -765,7 +759,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
@@ -1025,7 +1019,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
[num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@@ -1053,7 +1047,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
@@ -1278,18 +1272,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
|
||||
x = 16 // key_cache.element_size()
|
||||
k_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, head_size // x, block_size, x],
|
||||
dtype=key_cache.dtype,
|
||||
device="meta",
|
||||
new_key_cache = key_cache.reshape(
|
||||
num_blocks, num_kv_heads, head_size // x, block_size, x
|
||||
)
|
||||
v_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, block_size // x, head_size, x],
|
||||
dtype=value_cache.dtype,
|
||||
device="meta",
|
||||
new_value_cache = value_cache.reshape(
|
||||
num_blocks, num_kv_heads, block_size // x, head_size, x
|
||||
)
|
||||
new_key_cache = key_cache.view_as(k_cache_template)
|
||||
new_value_cache = value_cache.view_as(v_cache_template)
|
||||
k_qscale = (
|
||||
layer._k_scale
|
||||
if attn_metadata.k_scale is None
|
||||
@@ -1378,7 +1366,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
@@ -1446,7 +1434,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
layer_slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
flash_layout = True
|
||||
|
||||
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
|
||||
@@ -74,7 +74,7 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
@@ -153,7 +153,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
[num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@@ -194,7 +194,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
layer,
|
||||
)
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
softmax_scale = self.scale
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
@@ -243,7 +243,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
ops.reshape_and_cache_flash(
|
||||
@@ -276,7 +276,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
flash_layout = True
|
||||
|
||||
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
|
||||
@@ -208,6 +208,12 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
def supports_non_causal(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_kv_connector(cls) -> bool:
|
||||
# ROCM_ATTN uses (2, num_blocks, ...) KV cache layout which is
|
||||
# incompatible with KV connectors that require blocks-first layout.
|
||||
return False
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -31,6 +31,7 @@ class AttentionSelectorConfig(NamedTuple):
|
||||
attn_type: str = AttentionType.DECODER
|
||||
use_non_causal: bool = False
|
||||
use_batch_invariant: bool = False
|
||||
use_kv_connector: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@@ -45,7 +46,8 @@ class AttentionSelectorConfig(NamedTuple):
|
||||
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
|
||||
f"attn_type={self.attn_type}, "
|
||||
f"use_non_causal={self.use_non_causal}, "
|
||||
f"use_batch_invariant={self.use_batch_invariant})"
|
||||
f"use_batch_invariant={self.use_batch_invariant}, "
|
||||
f"use_kv_connector={self.use_kv_connector})"
|
||||
)
|
||||
|
||||
|
||||
@@ -80,6 +82,11 @@ def get_attn_backend(
|
||||
else:
|
||||
block_size = None
|
||||
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
use_kv_connector = (
|
||||
kv_transfer_config is not None and kv_transfer_config.is_kv_transfer_instance
|
||||
)
|
||||
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
@@ -93,6 +100,7 @@ def get_attn_backend(
|
||||
attn_type=attn_type or AttentionType.DECODER,
|
||||
use_non_causal=vllm_config.attention_config.use_non_causal,
|
||||
use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
|
||||
use_kv_connector=use_kv_connector,
|
||||
)
|
||||
|
||||
return _cached_get_attn_backend(
|
||||
|
||||
Reference in New Issue
Block a user