From b4a6f26c904c4dc1653c1bb623c4cb8186347175 Mon Sep 17 00:00:00 2001 From: Tuukka Sarvi Date: Fri, 5 Jun 2026 09:46:29 +0300 Subject: [PATCH] [ROCm][perf] Use workspace manager for sparse indexer allocations (#41002) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Stig-Arne Grönroos Signed-off-by: Tuukka Sarvi Co-authored-by: Stig-Arne Grönroos Co-authored-by: TJian --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 79 +++++++++++++------ 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 332350d8380..12fd3a17421 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -8,6 +8,7 @@ from importlib.util import find_spec import torch import torch.nn.functional as F +import vllm.envs as envs from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.forward_context import get_forward_context from vllm.platforms import current_platform @@ -15,6 +16,7 @@ from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import LayerNameType from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.v1.worker.workspace import current_workspace_manager if current_platform.is_rocm(): from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950 @@ -408,8 +410,8 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module = None # if rocm_aiter_ops.is_enabled(): - batch_size, next_n, heads, head_dim = q_fp8.shape - num_blocks, block_size, _, _ = kv_cache_fp8.shape + batch_size, next_n = q_fp8.shape[:2] + block_size = kv_cache_fp8.shape[1] if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() @@ -420,12 +422,10 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits ) batch_size, next_n, heads, _ = q_fp8.shape - out_logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device="cuda", - dtype=torch.float32, + (out_logits,) = current_workspace_manager().get_simultaneous( + ((batch_size * next_n, max_model_len), torch.float32), ) + out_logits.fill_(float("-inf")) deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -444,12 +444,10 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 ) batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), - float("-inf"), - device="cuda", - dtype=torch.float32, + (out_qk,) = current_workspace_manager().get_simultaneous( + ((heads, batch_size * next_n, max_model_len), torch.float32), ) + out_qk.fill_(float("-inf")) deepgemm_fp8_paged_mqa_logits_stage1( q_fp8, kv_cache_fp8, @@ -647,6 +645,43 @@ def rocm_aiter_sparse_attn_indexer( k_cache_prefix = _resolve_layer_name(k_cache_prefix) # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): + # Profiling early-exit: reserve memory to account for runtime + # allocations. Must be in the real impl, not the fake impl — + # torch.compile calls the fake impl under FakeTensor mode where + # workspace manager operations on the locked real workspace + # would corrupt PyTorch's dispatch state. + workspace_manager = current_workspace_manager() + + # Prefill k_fp8 and k_scale buffers, used by + # rocm_aiter_sparse_attn_indexer's prefill path + workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + + # Decode logits buffer, used by rocm_fp8_paged_mqa_logits. + # batch_size * next_n <= hidden_states.shape[0] == max_num_batched_tokens + if _ON_GFX942 or _ON_GFX950: + workspace_manager.get_simultaneous( + ((hidden_states.shape[0], max_model_len), torch.float32), + ) + else: + workspace_manager.get_simultaneous( + ( + (q_fp8.shape[1], hidden_states.shape[0], max_model_len), + torch.float32, + ), + ) + # Transient logits tensor peak memory, produced by + # rocm_fp8_mqa_logits (prefill) and rocm_fp8_paged_mqa_logits + # (decode). Prefill logits are bounded by + # VLLM_SPARSE_INDEXER_MAX_LOGITS_MB via chunking in + # split_indexer_prefill_chunks; decode logits are smaller. + max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + _ = torch.empty( + max_logits_elems, dtype=torch.uint8, device=hidden_states.device + ) + return rocm_aiter_sparse_attn_indexer_fake( hidden_states, k_cache_prefix, @@ -671,7 +706,6 @@ def rocm_aiter_sparse_attn_indexer( has_decode = layer_attn_metadata.num_decodes > 0 has_prefill = layer_attn_metadata.num_prefills > 0 num_decode_tokens = layer_attn_metadata.num_decode_tokens - device = hidden_states.device if k is None else k.device # during speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. @@ -703,17 +737,15 @@ def rocm_aiter_sparse_attn_indexer( if has_prefill: prefill_metadata = layer_attn_metadata.prefill assert prefill_metadata is not None + + workspace_manager = current_workspace_manager() + k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=device, - dtype=fp8_dtype, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=device, - dtype=torch.uint8, - ) + k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] if _ON_GFX942: ops.cp_gather_indexer_k_quant_cache( kv_cache, @@ -731,7 +763,6 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seq_lens, token_to_seq=chunk.token_to_seq, ) - logits = rocm_fp8_mqa_logits( q_fp8[chunk.token_start : chunk.token_end], (k_fp8, k_scale.view(torch.float32)),