mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][perf] Use workspace manager for sparse indexer allocations (#41002)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com> Signed-off-by: Tuukka Sarvi <tuukka.sarvi@amd.com> Co-authored-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from importlib.util import find_spec
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
@@ -15,6 +16,7 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import LayerNameType
|
||||
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
|
||||
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950
|
||||
@@ -408,8 +410,8 @@ def rocm_fp8_paged_mqa_logits(
|
||||
|
||||
aiter_paged_mqa_logits_module = None
|
||||
# if rocm_aiter_ops.is_enabled():
|
||||
batch_size, next_n, heads, head_dim = q_fp8.shape
|
||||
num_blocks, block_size, _, _ = kv_cache_fp8.shape
|
||||
batch_size, next_n = q_fp8.shape[:2]
|
||||
block_size = kv_cache_fp8.shape[1]
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
|
||||
@@ -420,12 +422,10 @@ def rocm_fp8_paged_mqa_logits(
|
||||
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits
|
||||
)
|
||||
batch_size, next_n, heads, _ = q_fp8.shape
|
||||
out_logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
(out_logits,) = current_workspace_manager().get_simultaneous(
|
||||
((batch_size * next_n, max_model_len), torch.float32),
|
||||
)
|
||||
out_logits.fill_(float("-inf"))
|
||||
deepgemm_fp8_paged_mqa_logits(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
@@ -444,12 +444,10 @@ def rocm_fp8_paged_mqa_logits(
|
||||
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
|
||||
)
|
||||
batch_size, next_n, heads, _ = q_fp8.shape
|
||||
out_qk = torch.full(
|
||||
(heads, batch_size * next_n, max_model_len),
|
||||
float("-inf"),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
(out_qk,) = current_workspace_manager().get_simultaneous(
|
||||
((heads, batch_size * next_n, max_model_len), torch.float32),
|
||||
)
|
||||
out_qk.fill_(float("-inf"))
|
||||
deepgemm_fp8_paged_mqa_logits_stage1(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
@@ -647,6 +645,43 @@ def rocm_aiter_sparse_attn_indexer(
|
||||
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
|
||||
# assert isinstance(attn_metadata, dict)
|
||||
if not isinstance(attn_metadata, dict):
|
||||
# Profiling early-exit: reserve memory to account for runtime
|
||||
# allocations. Must be in the real impl, not the fake impl —
|
||||
# torch.compile calls the fake impl under FakeTensor mode where
|
||||
# workspace manager operations on the locked real workspace
|
||||
# would corrupt PyTorch's dispatch state.
|
||||
workspace_manager = current_workspace_manager()
|
||||
|
||||
# Prefill k_fp8 and k_scale buffers, used by
|
||||
# rocm_aiter_sparse_attn_indexer's prefill path
|
||||
workspace_manager.get_simultaneous(
|
||||
((total_seq_lens, head_dim), fp8_dtype),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
|
||||
# Decode logits buffer, used by rocm_fp8_paged_mqa_logits.
|
||||
# batch_size * next_n <= hidden_states.shape[0] == max_num_batched_tokens
|
||||
if _ON_GFX942 or _ON_GFX950:
|
||||
workspace_manager.get_simultaneous(
|
||||
((hidden_states.shape[0], max_model_len), torch.float32),
|
||||
)
|
||||
else:
|
||||
workspace_manager.get_simultaneous(
|
||||
(
|
||||
(q_fp8.shape[1], hidden_states.shape[0], max_model_len),
|
||||
torch.float32,
|
||||
),
|
||||
)
|
||||
# Transient logits tensor peak memory, produced by
|
||||
# rocm_fp8_mqa_logits (prefill) and rocm_fp8_paged_mqa_logits
|
||||
# (decode). Prefill logits are bounded by
|
||||
# VLLM_SPARSE_INDEXER_MAX_LOGITS_MB via chunking in
|
||||
# split_indexer_prefill_chunks; decode logits are smaller.
|
||||
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
|
||||
_ = torch.empty(
|
||||
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
|
||||
)
|
||||
|
||||
return rocm_aiter_sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
@@ -671,7 +706,6 @@ def rocm_aiter_sparse_attn_indexer(
|
||||
has_decode = layer_attn_metadata.num_decodes > 0
|
||||
has_prefill = layer_attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = layer_attn_metadata.num_decode_tokens
|
||||
device = hidden_states.device if k is None else k.device
|
||||
|
||||
# during speculative decoding, k may be padded to the CUDA graph batch
|
||||
# size while slot_mapping only covers actual tokens.
|
||||
@@ -703,17 +737,15 @@ def rocm_aiter_sparse_attn_indexer(
|
||||
if has_prefill:
|
||||
prefill_metadata = layer_attn_metadata.prefill
|
||||
assert prefill_metadata is not None
|
||||
|
||||
workspace_manager = current_workspace_manager()
|
||||
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
|
||||
((total_seq_lens, head_dim), fp8_dtype),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = torch.empty(
|
||||
[chunk.total_seq_lens, head_dim],
|
||||
device=device,
|
||||
dtype=fp8_dtype,
|
||||
)
|
||||
k_scale = torch.empty(
|
||||
[chunk.total_seq_lens, 4],
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
|
||||
k_scale = k_scale_full[: chunk.total_seq_lens]
|
||||
if _ON_GFX942:
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
@@ -731,7 +763,6 @@ def rocm_aiter_sparse_attn_indexer(
|
||||
chunk.cu_seq_lens,
|
||||
token_to_seq=chunk.token_to_seq,
|
||||
)
|
||||
|
||||
logits = rocm_fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32)),
|
||||
|
||||
Reference in New Issue
Block a user