[ROCm][perf] Use workspace manager for sparse indexer allocations (#41002)

Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Signed-off-by: Tuukka Sarvi <tuukka.sarvi@amd.com>
Co-authored-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Tuukka Sarvi
2026-06-05 09:46:29 +03:00
committed by GitHub
parent 165b7864d0
commit b4a6f26c90
+55 -24
View File
@@ -8,6 +8,7 @@ from importlib.util import find_spec
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
@@ -15,6 +16,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import LayerNameType
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager
if current_platform.is_rocm():
from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950
@@ -408,8 +410,8 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module = None
# if rocm_aiter_ops.is_enabled():
batch_size, next_n, heads, head_dim = q_fp8.shape
num_blocks, block_size, _, _ = kv_cache_fp8.shape
batch_size, next_n = q_fp8.shape[:2]
block_size = kv_cache_fp8.shape[1]
if rocm_aiter_ops.is_enabled():
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
@@ -420,12 +422,10 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits
)
batch_size, next_n, heads, _ = q_fp8.shape
out_logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device="cuda",
dtype=torch.float32,
(out_logits,) = current_workspace_manager().get_simultaneous(
((batch_size * next_n, max_model_len), torch.float32),
)
out_logits.fill_(float("-inf"))
deepgemm_fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
@@ -444,12 +444,10 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
)
batch_size, next_n, heads, _ = q_fp8.shape
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
device="cuda",
dtype=torch.float32,
(out_qk,) = current_workspace_manager().get_simultaneous(
((heads, batch_size * next_n, max_model_len), torch.float32),
)
out_qk.fill_(float("-inf"))
deepgemm_fp8_paged_mqa_logits_stage1(
q_fp8,
kv_cache_fp8,
@@ -647,6 +645,43 @@ def rocm_aiter_sparse_attn_indexer(
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
# Profiling early-exit: reserve memory to account for runtime
# allocations. Must be in the real impl, not the fake impl —
# torch.compile calls the fake impl under FakeTensor mode where
# workspace manager operations on the locked real workspace
# would corrupt PyTorch's dispatch state.
workspace_manager = current_workspace_manager()
# Prefill k_fp8 and k_scale buffers, used by
# rocm_aiter_sparse_attn_indexer's prefill path
workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
# Decode logits buffer, used by rocm_fp8_paged_mqa_logits.
# batch_size * next_n <= hidden_states.shape[0] == max_num_batched_tokens
if _ON_GFX942 or _ON_GFX950:
workspace_manager.get_simultaneous(
((hidden_states.shape[0], max_model_len), torch.float32),
)
else:
workspace_manager.get_simultaneous(
(
(q_fp8.shape[1], hidden_states.shape[0], max_model_len),
torch.float32,
),
)
# Reserve peak memory for the transient logits tensors produced by
# rocm_fp8_mqa_logits (prefill) and rocm_fp8_paged_mqa_logits
# (decode). Prefill logits are bounded by
# VLLM_SPARSE_INDEXER_MAX_LOGITS_MB via chunking in
# split_indexer_prefill_chunks; decode logits are smaller.
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
_ = torch.empty(
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
)
return rocm_aiter_sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
@@ -671,7 +706,6 @@ def rocm_aiter_sparse_attn_indexer(
has_decode = layer_attn_metadata.num_decodes > 0
has_prefill = layer_attn_metadata.num_prefills > 0
num_decode_tokens = layer_attn_metadata.num_decode_tokens
device = hidden_states.device if k is None else k.device
# during speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens.
@@ -703,17 +737,15 @@ def rocm_aiter_sparse_attn_indexer(
if has_prefill:
prefill_metadata = layer_attn_metadata.prefill
assert prefill_metadata is not None
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
device=device,
dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
device=device,
dtype=torch.uint8,
)
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
if _ON_GFX942:
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
@@ -731,7 +763,6 @@ def rocm_aiter_sparse_attn_indexer(
chunk.cu_seq_lens,
token_to_seq=chunk.token_to_seq,
)
logits = rocm_fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),