[DSV4] Refactor DeepseekV4Attention (#44569)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-06-04 20:23:07 -07:00
committed by GitHub
parent 56aff0dd15
commit 4efd6ffde0
8 changed files with 520 additions and 917 deletions
+3 -161
View File
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
@@ -45,11 +44,7 @@ from vllm.model_executor.models.utils import (
make_layers,
maybe_prefix,
)
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLA,
)
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.amd.rocm import DeepseekV4ROCMAiterMLAAttention
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_tilelang
@@ -225,158 +220,6 @@ class DeepseekV4MoE(nn.Module):
return final_hidden_states.view(org_shape)
# Attention block for a single DeepseekV4 layer. The constructor reads the
# HF config, builds the projection / norm / rotary / (optional) indexer
# submodules, and hands all of them to a DeepseekV4MLA instance; forward()
# simply delegates to that instance.
class DeepseekV4Attention(nn.Module):
# Parameters:
#   vllm_config: source of the HF model config, quant config and cache config.
#   prefix: module name; the layer index is parsed out of it and it is used to
#       name every submodule created here.
#   topk_indices_buffer: optional tensor forwarded to the indexer and to
#       DeepseekV4MLA.
#   aux_stream_list: optional list of CUDA streams forwarded to DeepseekV4MLA.
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
layer_id = extract_layer_index(prefix)
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.n_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
# Heads are split evenly across tensor-parallel ranks.
assert self.n_heads % tp_size == 0
self.n_local_heads = self.n_heads // tp_size
self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
self.head_dim = config.head_dim
self.rope_head_dim = config.qk_rope_head_dim
# Whatever part of the head is not rotary is the "nope" part.
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.n_groups = config.o_groups
# NOTE(review): unlike n_heads, o_groups is floor-divided by tp_size
# without a divisibility assert — confirm this is always exact.
self.n_local_groups = self.n_groups // tp_size
self.window_size = config.sliding_window
# NOTE(zyongye): compress_ratio can never be 0. MTP layers are not
# included in config.compress_ratios, so they (and any 0 entry) fall
# back to a ratio of 1.
if layer_id < config.num_hidden_layers:
self.compress_ratio = max(1, config.compress_ratios[layer_id])
else:
self.compress_ratio = 1
self.eps = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
# Padded to min 64 heads for FlashMLA, initialized to -inf
# (no sink effect). Weight loading fills the first n_local_heads slots.
padded_heads = max(self.n_local_heads, 64)
self.attn_sink = nn.Parameter(
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
# One merged input projection producing two outputs of widths
# q_lora_rank and head_dim; tensor parallelism is disabled for it.
self.fused_wqa_wkv = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_wqa_wkv",
disable_tp=True, # fused ReplicatedLinear
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
# Up-projection from the q low-rank space to all heads.
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wq_b",
)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
# Grouped output projection: input width is the per-group slice of the
# heads, output width is n_groups * o_lora_rank.
self.wo_a = ColumnParallelLinear(
self.n_heads * self.head_dim // self.n_groups,
self.n_groups * self.o_lora_rank,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_a",
)
# Tag wo_a as a batched matmul with one batch entry per local group.
# NOTE(review): the consumer of these two attributes is not visible
# here — presumably weight loading / quantization; confirm.
self.wo_a.is_bmm = True
self.wo_a.bmm_batch_size = self.n_local_groups
self.wo_b = RowParallelLinear(
self.n_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_b",
)
self.softmax_scale = self.head_dim**-0.5
# Indexed lookup: the "scale_fmt" entry must exist in
# config.quantization_config.
self.scale_fmt = config.quantization_config["scale_fmt"]
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
self.rotary_emb = build_deepseek_v4_rope(
config,
head_dim=self.head_dim,
rope_head_dim=self.rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
compress_ratio=self.compress_ratio,
)
self.indexer = None
if self.compress_ratio == 4:
# Only C4A uses sparse attention and hence has indexer.
self.indexer = DeepseekV4Indexer(
vllm_config,
config=config,
hidden_size=self.hidden_size,
q_lora_rank=self.q_lora_rank,
quant_config=quant_config,
cache_config=vllm_config.cache_config,
topk_indices_buffer=topk_indices_buffer,
compress_ratio=self.compress_ratio,
prefix=f"{prefix}.indexer",
)
# The MLA module receives the submodules built above rather than creating
# its own. v_head_dim and kv_lora_rank are both set to head_dim, and the
# indexer shares the same rotary embedding as the main attention path.
self.mla_attn = DeepseekV4MLA(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
wq_b=self.wq_b,
kv_norm=self.kv_norm,
wo_a=self.wo_a,
wo_b=self.wo_b,
attn_sink=self.attn_sink,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
prefix=prefix,
)
# Run attention for one step by delegating to the wrapped DeepseekV4MLA
# module; the three arguments are passed through unchanged and its result is
# returned as-is.
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
):
return self.mla_attn(positions, hidden_states, llama_4_scaling)
class DeepseekV4DecoderLayer(nn.Module):
def __init__(
self,
@@ -395,7 +238,7 @@ class DeepseekV4DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
self.rms_norm_eps = config.rms_norm_eps
self.attn = DeepseekV4Attention(
self.attn = DeepseekV4ROCMAiterMLAAttention(
vllm_config,
prefix=f"{prefix}.attn",
topk_indices_buffer=topk_indices_buffer,
@@ -601,7 +444,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MLA.attn_gemm_parallel_execute
# DeepseekV4Attention.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
# Disable them on ROCm because of hang issues.
@@ -897,7 +740,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
".attn.compressor.": ".attn.mla_attn.compressor.",
".shared_experts.w2": ".shared_experts.down_proj",
},
)
+65 -68
View File
@@ -2,15 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from typing import cast
import torch
from vllm.forward_context import get_forward_context
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.common.ops import dequantize_and_gather_k_cache
from vllm.models.deepseek_v4.nvidia.flashmla import (
DeepseekV4FlashMLASparseBackend,
DeepseekV4SparseMLAAttentionImpl,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
@@ -26,16 +26,12 @@ from vllm.v1.attention.backends.mla.sparse_swa import (
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
build_ragged_indices_from_dense,
rocm_inv_rope_einsum,
rocm_sparse_attn_decode,
rocm_sparse_attn_prefill,
)
from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.models.deepseek_v4.attention import (
DeepseekV4MLAAttention,
)
def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
lengths = lengths.to(dtype=torch.int32).contiguous()
@@ -582,13 +578,9 @@ class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend):
def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]:
return DeepseekV4ROCMAiterMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
return DeepseekV4ROCMAiterMLASparseImpl
class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
"""ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer."""
class DeepseekV4ROCMAiterMLAAttention(DeepseekV4Attention):
"""ROCm sparse MLA attention layer for DeepSeek V4."""
backend_cls = DeepseekV4ROCMAiterMLASparseBackend
@@ -596,10 +588,21 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
def get_padded_num_q_heads(cls, num_heads: int) -> int:
return num_heads
@classmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Project the attention output back to the hidden size on ROCm.

    Uses the BF16 reference path: ``rocm_inv_rope_einsum`` handles the
    inverse RoPE and the grouped ``wo_a`` einsum, then the result is
    flattened from dim 1 onward and passed through ``wo_b``.
    """
    low_rank = rocm_inv_rope_einsum(
        self.rotary_emb,
        o,
        positions,
        self.rope_head_dim,
        self.n_local_groups,
        self.o_lora_rank,
        self.wo_a,
    )
    # Collapse every dim after the first before the final projection.
    flattened = low_rank.flatten(1)
    return self.wo_b(flattened)
def forward_mqa(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -619,16 +622,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# Warmup dummy run: no real metadata. Reserve the same bf16
# gather workspace _forward_prefill would; the dequantize / topk
# / sparse_fwd kernels are skipped this step.
swa_only = layer.compress_ratio <= 1
swa_only = self.compress_ratio <= 1
N = (
0
if swa_only
else (layer.max_model_len + layer.compress_ratio - 1)
// layer.compress_ratio
else (self.max_model_len + self.compress_ratio - 1)
// self.compress_ratio
)
M = N + layer.window_size + layer.max_num_batched_tokens
M = N + self.window_size + self.max_num_batched_tokens
current_workspace_manager().get_simultaneous(
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
output.zero_()
return
@@ -636,25 +639,24 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert isinstance(attn_metadata, dict)
rocm_metadata = cast(
DeepseekV4ROCMAiterMLASparseMetadata | None,
attn_metadata.get(layer.prefix),
attn_metadata.get(self.prefix),
)
swa_metadata = cast(
DeepseekV4ROCMAiterSparseSWAMetadata | None,
attn_metadata.get(layer.swa_cache_layer.prefix),
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_only = layer.compress_ratio <= 1
self_kv_cache = layer.kv_cache if not swa_only else None
swa_kv_cache = layer.swa_cache_layer.kv_cache
swa_only = self.compress_ratio <= 1
self_kv_cache = self.kv_cache if not swa_only else None
swa_kv_cache = self.swa_cache_layer.kv_cache
num_decodes = swa_metadata.num_decodes
num_prefills = swa_metadata.num_prefills
num_decode_tokens = swa_metadata.num_decode_tokens
if num_prefills > 0:
cls._forward_prefill(
layer=layer,
self._forward_prefill(
q=q[num_decode_tokens:],
positions=positions[num_decode_tokens:],
compressed_k_cache=self_kv_cache,
@@ -664,8 +666,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_metadata=swa_metadata,
)
if num_decodes > 0:
cls._forward_decode(
layer=layer,
self._forward_decode(
q=q[:num_decode_tokens],
kv_cache=self_kv_cache,
swa_metadata=swa_metadata,
@@ -674,10 +675,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
output=output[:num_decode_tokens],
)
@classmethod
def _forward_decode(
cls,
layer: "DeepseekV4MLAAttention",
self,
q: torch.Tensor,
kv_cache: torch.Tensor | None,
swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata,
@@ -695,16 +694,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
if not swa_only:
assert attn_metadata is not None
assert swa_metadata.is_valid_token is not None
block_size = attn_metadata.block_size // layer.compress_ratio
block_size = attn_metadata.block_size // self.compress_ratio
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
if layer.compress_ratio == 4:
assert layer.topk_indices_buffer is not None
if self.compress_ratio == 4:
assert self.topk_indices_buffer is not None
(
topk_ragged_indices,
topk_ragged_indptr,
topk_lens,
) = compute_global_topk_ragged_indices_and_indptr(
layer.topk_indices_buffer[:num_decode_tokens],
self.topk_indices_buffer[:num_decode_tokens],
swa_metadata.token_to_req_indices,
attn_metadata.block_table[:num_decodes],
block_size,
@@ -719,7 +718,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
rocm_sparse_attn_decode(
q=q,
kv_cache=kv_cache,
swa_k_cache=layer.swa_cache_layer.kv_cache,
swa_k_cache=self.swa_cache_layer.kv_cache,
swa_only=swa_only,
topk_indices=topk_indices,
topk_lens=topk_lens,
@@ -729,18 +728,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_ragged_indptr=swa_metadata.decode_swa_ragged_indptr,
topk_ragged_indices=topk_ragged_indices,
topk_ragged_indptr=topk_ragged_indptr,
attn_sink=layer.attn_sink,
scale=layer.scale,
head_dim=layer.head_dim,
nope_head_dim=layer.nope_head_dim,
rope_head_dim=layer.rope_head_dim,
attn_sink=self.attn_sink,
scale=self.scale,
head_dim=self.head_dim,
nope_head_dim=self.nope_head_dim,
rope_head_dim=self.rope_head_dim,
output=output,
)
@classmethod
def _forward_prefill(
cls,
layer: "DeepseekV4MLAAttention",
self,
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None,
@@ -768,34 +765,34 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
prefill_token_base = query_start_loc_cpu[num_decodes]
if not swa_only:
if layer.compress_ratio == 4:
assert layer.topk_indices_buffer is not None
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
if self.compress_ratio == 4:
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
topk_indices = topk_indices[:num_prefill_tokens]
else:
assert attn_metadata is not None
topk_indices = attn_metadata.c128a_prefill_topk_indices
assert topk_indices is not None
top_k = topk_indices.shape[-1]
N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio
N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
else:
assert layer.topk_indices_buffer is not None
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
top_k = 0
N = 0
M = N + layer.window_size + layer.max_num_batched_tokens
num_chunks = (num_prefills + cls.PREFILL_CHUNK_SIZE - 1) // (
cls.PREFILL_CHUNK_SIZE
M = N + self.window_size + self.max_num_batched_tokens
num_chunks = (num_prefills + self.PREFILL_CHUNK_SIZE - 1) // (
self.PREFILL_CHUNK_SIZE
)
workspace_manager = current_workspace_manager()
kv = workspace_manager.get_simultaneous(
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)[0]
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills)
chunk_start = chunk_idx * self.PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + self.PREFILL_CHUNK_SIZE, num_prefills)
chunk_size = chunk_end - chunk_start
if not swa_only:
assert attn_metadata is not None
@@ -804,10 +801,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
dequantize_and_gather_k_cache(
kv[:chunk_size],
compressed_k_cache,
seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio,
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
gather_lens=None,
block_table=block_table[chunk_start:chunk_end],
block_size=attn_metadata.block_size // layer.compress_ratio,
block_size=attn_metadata.block_size // self.compress_ratio,
offset=0,
)
@@ -836,8 +833,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
],
seq_lens[chunk_start:chunk_end],
gather_lens[chunk_start:chunk_end],
layer.window_size,
layer.compress_ratio,
self.window_size,
self.compress_ratio,
top_k,
M,
N,
@@ -847,10 +844,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices,
topk_length=combined_lens,
scale=layer.scale,
head_dim=layer.head_dim,
nope_head_dim=layer.nope_head_dim,
rope_head_dim=layer.rope_head_dim,
attn_sink=layer.attn_sink,
scale=self.scale,
head_dim=self.head_dim,
nope_head_dim=self.nope_head_dim,
rope_head_dim=self.rope_head_dim,
attn_sink=self.attn_sink,
output=output[query_start:query_end],
)
+222 -343
View File
@@ -4,8 +4,9 @@
DeepseekV4 MLA Attention Layer
"""
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
import torch
import torch.nn as nn
@@ -15,16 +16,16 @@ from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.models.deepseek_v4.common.ops import (
fused_indexer_q_rope_quant,
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
from vllm.utils.deep_gemm import fp8_einsum
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
@@ -42,14 +43,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.compressor import DeepseekCompressor
from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import (
execute_in_parallel,
maybe_execute_in_parallel,
@@ -62,187 +58,209 @@ from vllm.v1.attention.backends.mla.indexer import (
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.models.deepseek_v4.nvidia.flashmla import (
DeepseekV4SparseMLAAttentionImpl,
)
logger = init_logger(__name__)
def _resolve_dsv4_backend(vllm_config: VllmConfig | None):
"""Return the explicitly-requested DSv4 sparse backend enum, or None."""
if vllm_config is None:
return None
attn_config = getattr(vllm_config, "attention_config", None)
return getattr(attn_config, "backend", None) if attn_config is not None else None
def _select_v4_sparse_impl(
    vllm_config: VllmConfig | None = None,
) -> "type[DeepseekV4SparseMLAAttentionImpl]":
    """Choose the sparse MLA implementation class for DeepSeek V4.

    Precedence:
      1. An explicitly requested ``FLASHINFER_MLA_SPARSE_DSV4`` backend selects
         the FlashInfer implementation.
      2. Otherwise, on ROCm the Aiter implementation is used.
      3. Otherwise the FlashMLA implementation is used.

    Each implementation module is imported lazily inside its own branch, so
    only the selected module is imported. The choice is logged once.
    """
    from vllm.v1.attention.backends.registry import AttentionBackendEnum

    requested = _resolve_dsv4_backend(vllm_config)

    if requested == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
        from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
            DeepseekV4FlashInferMLASparseImpl as impl_cls,
        )

        message = "Using FLASHINFER_MLA_SPARSE_DSV4 backend."
    elif current_platform.is_rocm():
        from vllm.models.deepseek_v4.amd.rocm import (
            DeepseekV4ROCMAiterMLASparseImpl as impl_cls,
        )

        message = "Using ROCM_FLASHMLA_SPARSE_DSV4 backend."
    else:
        from vllm.models.deepseek_v4.nvidia.flashmla import (
            DeepseekV4FlashMLASparseImpl as impl_cls,
        )

        message = "Using FLASHMLA_SPARSE_DSV4 backend."

    logger.info_once(message)
    return impl_cls
def _resolve_dsv4_kv_cache_dtype(
backend,
use_flashmla_fp8_layout: bool,
kv_cache_dtype: str,
cache_config: CacheConfig | None,
) -> tuple[str, torch.dtype]:
"""Map ``(backend, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.
"""Map ``(layout, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.
FlashInfer V4 reads a contiguous 512-wide KV row (bf16 or per-tensor FP8
E4M3); FlashMLA V4 reads the legacy UE8M0 paged layout (uint8 /
``fp8_ds_mla``). For FlashMLA the canonical ``fp8_ds_mla`` string is
written back onto ``cache_config`` so the page-size specs pick the 576B
layout.
Both layouts are paged; they differ in the per-token block format. The
FlashMLA fp8 layout (FlashMLA / ROCm Aiter) is the ``fp8_ds_mla`` format:
UE8M0 block-scaled fp8 packed as ``uint8`` (the canonical ``fp8_ds_mla``
string is written back onto ``cache_config`` so the page-size specs pick
the 576B per-token slot). Otherwise (FlashInfer) each token's KV row is
stored in its plain element dtype — bf16 or per-tensor FP8 E4M3.
"""
from vllm.v1.attention.backends.registry import AttentionBackendEnum
if use_flashmla_fp8_layout:
# fp8_ds_mla block format: UE8M0 block-scaled fp8 packed as uint8.
assert kv_cache_dtype.startswith("fp8"), (
f"DeepseekV4 FlashMLA fp8 layout only supports fp8 kv-cache, "
f"got {kv_cache_dtype}"
)
if kv_cache_dtype != "fp8_ds_mla":
if cache_config is not None:
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
return kv_cache_dtype, torch.uint8
if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
if kv_cache_dtype.startswith("fp8"):
return kv_cache_dtype, torch.float8_e4m3fn
# auto / bfloat16 -> contiguous BF16 cache.
return kv_cache_dtype, torch.bfloat16
# FlashMLA (and ROCm Aiter): legacy UE8M0 paged uint8 cache.
assert kv_cache_dtype.startswith("fp8"), (
f"DeepseekV4 FlashMLA sparse backend only supports fp8 kv-cache, "
f"got {kv_cache_dtype}"
)
if kv_cache_dtype != "fp8_ds_mla":
if cache_config is not None:
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
return kv_cache_dtype, torch.uint8
# Plain bf16 / per-tensor fp8 KV row (FlashInfer).
if kv_cache_dtype.startswith("fp8"):
return kv_cache_dtype, torch.float8_e4m3fn
# auto / bfloat16 -> plain bf16 KV row.
return kv_cache_dtype, torch.bfloat16
class DeepseekV4MLA(nn.Module):
class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
"""DeepseekV4 MLA attention layer.
The platform-specific sparse-MLA forward (``forward_mqa`` /
``get_padded_num_q_heads`` / ``_o_proj`` / ``backend_cls``) is provided by a
subclass — ``DeepseekV4FlashMLAAttention`` / ``DeepseekV4FlashInferMLAAttention``
(CUDA) or ``DeepseekV4ROCMAiterMLAAttention`` (ROCm) — selected by the
platform-specific deepseek_v4 model module. The base is never instantiated
directly.
"""
# Provided by the platform subclass.
backend_cls: ClassVar[type[AttentionBackend]]
# KV-cache per-token block format (both layouts are paged). True (default)
# = FlashMLA / ROCm fp8_ds_mla (UE8M0 block-scaled fp8 packed as uint8);
# False = FlashInfer plain bf16 / per-tensor fp8 KV row.
use_flashmla_fp8_layout: ClassVar[bool] = True
# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
# workspace allocated in _forward_prefill and is also read by the dummy-run
# path to pre-reserve that workspace.
PREFILL_CHUNK_SIZE: ClassVar[int] = 4
@classmethod
@abstractmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
"""Q head count the q/output buffers are allocated at.

The layer allocates the q/output buffers at
``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must satisfy
``result >= num_heads``. Backends with no padding constraint return
``num_heads``.

The constructor calls this with ``n_local_heads``, stores the result as
``self.padded_heads`` and sizes ``attn_sink`` with it.

Args:
    num_heads: Number of query heads local to this rank.

Returns:
    The (possibly padded) head count to allocate.

Raises:
    NotImplementedError: Always, in this abstract base; platform
        subclasses must override.
"""
raise NotImplementedError
@abstractmethod
def forward_mqa(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
"""Platform-specific sparse MLA forward; writes attention into ``output``.

The result is delivered by mutating ``output`` in place; nothing is
returned.

Args:
    q: Query tensor for the current step.
    kv: Key/value tensor for the current step.
    positions: Token positions matching ``q``.
    output: Pre-allocated buffer the implementation fills.
        NOTE(review): a nearby comment describes it as
        ``[num_tokens, padded_heads, head_dim]`` — confirm.

Raises:
    NotImplementedError: Always, in this abstract base; platform
        subclasses must override.
"""
raise NotImplementedError
@abstractmethod
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
"""Inverse-RoPE + wo_a + wo_b output projection (platform-specific).

Called at the end of ``forward`` with the attention output sliced down to
the first ``n_local_heads`` heads; its return value is what ``forward``
returns.

Args:
    o: Attention output with head padding already removed.
    positions: Token positions, needed to undo the rotary embedding.

Returns:
    The projected tensor produced by ``wo_b``.

Raises:
    NotImplementedError: Always, in this abstract base; platform
        subclasses must override.
"""
raise NotImplementedError
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
o_lora_rank: int | None,
vllm_config: VllmConfig,
fused_wqa_wkv: torch.nn.Module,
q_norm: torch.nn.Module,
wq_b: torch.nn.Module,
kv_norm: torch.nn.Module,
wo_a: torch.nn.Module,
wo_b: torch.nn.Module,
attn_sink: torch.nn.Module,
rotary_emb: torch.nn.Module,
indexer: torch.nn.Module | None,
indexer_rotary_emb: torch.nn.Module,
topk_indices_buffer: torch.Tensor | None,
aux_stream_list: list[torch.cuda.Stream] | None,
window_size: int,
compress_ratio: int | None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.window_size = window_size
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
self.prefix = prefix
# Extract config from vllm_config
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
tp_size = get_tensor_model_parallel_world_size()
layer_id = extract_layer_index(prefix)
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
self.eps = config.rms_norm_eps
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = head_dim - self.rope_head_dim
self.n_local_groups = config.o_groups // tp_size
self.prefix = prefix # Alias for compatibility with compressor
self.hidden_size = config.hidden_size
self.n_heads = config.num_attention_heads
assert self.n_heads % tp_size == 0
self.n_local_heads = self.n_heads // tp_size
self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
self.head_dim = config.head_dim
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.n_groups = config.o_groups
self.n_local_groups = self.n_groups // tp_size
self.window_size = config.sliding_window
# NOTE(zyongye): compress_ratio can never be 0. MTP layers are not
# included in config.compress_ratios, so they (and any 0 entry) fall
# back to a ratio of 1.
if layer_id < config.num_hidden_layers:
self.compress_ratio = max(1, config.compress_ratios[layer_id])
else:
self.compress_ratio = 1
self.eps = config.rms_norm_eps
self.scale = self.head_dim**-0.5
# Store projection modules
self.fused_wqa_wkv = fused_wqa_wkv
self.q_norm = q_norm
self.wq_b = wq_b
self.kv_norm = kv_norm
self.wo_a = wo_a
self._wo_a_act_quant = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
use_ue8m0=True,
# Padded Q head count is dictated by the platform subclass.
self.padded_heads = self.get_padded_num_q_heads(self.n_local_heads)
# Sink padded to the same head count, initialized to -inf (no sink
# effect). Weight loading fills the first n_local_heads slots.
self.attn_sink = nn.Parameter(
torch.full((self.padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
# INT32) so fp8_einsum can handle layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = wo_b
# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
# SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
cap = current_platform.get_device_capability()
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
self._tma_aligned_scales = cap.major >= 10
self.fused_wqa_wkv = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_wqa_wkv",
disable_tp=True, # fused ReplicatedLinear
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wq_b",
)
self.rotary_emb = rotary_emb
self.indexer_rotary_emb = indexer_rotary_emb
self.kv_norm = RMSNorm(self.head_dim, self.eps)
self.wo_a = ColumnParallelLinear(
self.n_heads * self.head_dim // self.n_groups,
self.n_groups * self.o_lora_rank,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_a",
)
self.wo_a.is_bmm = True
self.wo_a.bmm_batch_size = self.n_local_groups
self.wo_b = RowParallelLinear(
self.n_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_b",
)
# Initialize rotary embedding before the indexer/compressor consume it.
self.rotary_emb = build_deepseek_v4_rope(
config,
head_dim=self.head_dim,
rope_head_dim=self.rope_head_dim,
max_position_embeddings=config.max_position_embeddings,
compress_ratio=self.compress_ratio,
)
self.indexer_rotary_emb = self.rotary_emb
self.topk_indices_buffer = topk_indices_buffer
self.indexer = indexer
# Per-head RMS normalization for Q (no learnable weights)
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
# TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
head_bytes = (
self.nope_head_dim # 448 fp8 NoPE
+ self.rope_head_dim * 2 # 64 bf16 RoPE
+ self.nope_head_dim // 64 # 7B scale factors
+ 1 # 1B pad
)
self.indexer = None
if self.compress_ratio == 4:
# Only C4A uses sparse attention and hence has indexer.
# aux_stream_list[2] is free here (outer GEMMs joined) for the inner
# overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on
# ROCm, where aux_stream_list is None.
indexer_aux_stream = (
aux_stream_list[2] if aux_stream_list is not None else None
)
self.indexer = DeepseekV4Indexer(
vllm_config,
config=config,
hidden_size=self.hidden_size,
q_lora_rank=self.q_lora_rank,
quant_config=quant_config,
cache_config=cache_config,
topk_indices_buffer=topk_indices_buffer,
compress_ratio=self.compress_ratio,
prefix=f"{prefix}.indexer",
aux_stream=indexer_aux_stream,
)
# Will be None on ROCm for now.
self.aux_stream_list = aux_stream_list
@@ -252,45 +270,39 @@ class DeepseekV4MLA(nn.Module):
self.ln_events = [torch.cuda.Event() for _ in range(4)]
assert cache_config is not None, "DeepseekV4 attention requires cache_config"
# Resolve the SWA cache tensor dtype from the selected backend: FlashMLA
# uses the legacy UE8M0 paged uint8 layout; FlashInfer uses a contiguous
# bf16 / per-tensor fp8 row.
backend = _resolve_dsv4_backend(vllm_config)
_, swa_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
backend, cache_config.cache_dtype, cache_config
# ---- Attention / KV-cache setup ----
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
self.max_model_len = vllm_config.model_config.max_model_len
# Resolve the kv-cache dtype from this backend's block format (a
# ClassVar set by the subclass): fp8_ds_mla (UE8M0 block-scaled fp8 as
# uint8) for FlashMLA / ROCm, vs a plain bf16 / per-tensor fp8 row for
# FlashInfer. The same resolution drives the SWA cache tensor dtype
# below.
self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
self.use_flashmla_fp8_layout, cache_config.cache_dtype, cache_config
)
self.swa_cache_layer = DeepseekV4SWACache(
head_dim=self.head_dim,
window_size=self.window_size,
dtype=swa_cache_torch_dtype,
dtype=self.kv_cache_torch_dtype,
prefix=f"{prefix}.swa_cache",
cache_config=cache_config,
)
self.mla_attn = DeepseekV4MLAAttention(
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
compress_ratio=self.compress_ratio,
window_size=self.window_size,
head_bytes=head_bytes,
swa_cache_layer=self.swa_cache_layer,
attn_sink=attn_sink, # already padded with -inf
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
indexer=self.indexer,
topk_indices_buffer=self.topk_indices_buffer,
)
# Mirror the inner layer's padded head count (single source of truth).
self.padded_heads = self.mla_attn.padded_heads
# Register with compilation context for metadata lookup.
compilation_config = vllm_config.compilation_config
if prefix and prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
if prefix:
compilation_config.static_forward_context[prefix] = self
self.kv_cache = torch.tensor([])
# Create the compressor for layers with compress_ratio > 1; after
# creating the DeepseekV4MLAAttention layer to get its cache.
# Create the compressor for layers with compress_ratio > 1; after the
# attention setup above so its KV-cache prefix (self.prefix) is set.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
@@ -300,7 +312,7 @@ class DeepseekV4MLA(nn.Module):
head_dim=self.head_dim,
rotate=True,
prefix=f"{prefix}.compressor",
k_cache_prefix=self.mla_attn.prefix,
k_cache_prefix=self.prefix,
)
def forward(
@@ -324,48 +336,8 @@ class DeepseekV4MLA(nn.Module):
self.attention_impl(hidden_states, positions, o_padded)
o = o_padded[:, : self.n_local_heads, :]
# Keep ROCm on the BF16 reference wo_a path until the kernel is ready.
if current_platform.is_rocm():
z = rocm_inv_rope_einsum(
self.rotary_emb,
o,
positions,
self.rope_head_dim,
self.n_local_groups,
self.o_lora_rank,
self.wo_a,
)
return self.wo_b(z.flatten(1))
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
positions,
self.rotary_emb.cos_sin_cache,
n_groups=self.n_local_groups,
heads_per_group=self.n_local_heads // self.n_local_groups,
nope_dim=self.nope_head_dim,
rope_dim=self.rope_head_dim,
tma_aligned_scales=self._tma_aligned_scales,
)
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
z = torch.empty(
(num_tokens, self.n_local_groups, self.o_lora_rank),
device=o.device,
dtype=torch.bfloat16,
)
fp8_einsum(
"bhr,hdr->bhd",
(o_fp8, o_scale),
(wo_a_fp8, wo_a_scale),
z,
recipe=self._einsum_recipe,
)
return self.wo_b(z.flatten(1))
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
return self._o_proj(o, positions)
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
aux_streams = self.aux_stream_list
@@ -451,7 +423,7 @@ class DeepseekV4MLA(nn.Module):
)
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
# on the default stream so q stays on its consumer stream (mla_attn
# on the default stream so q stays on its consumer stream (forward_mqa
# downstream reads q on default). Indexer/compressor go on aux for
# overlap with default's GEMM + cache write.
if self.indexer is not None:
@@ -514,7 +486,7 @@ class DeepseekV4MLA(nn.Module):
# MLA attention writes into the pre-allocated `out` buffer
# ([num_tokens, padded_heads, head_dim]).
self.mla_attn(q, kv, positions, output=out)
self.forward_mqa(q, kv, positions, out)
def _fused_qnorm_rope_kv_insert(
self,
@@ -549,7 +521,7 @@ class DeepseekV4MLA(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache
cache_dtype = swa_kv_cache.dtype
# kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
# kv is unchanged; attention reads kv solely via swa_kv_cache.
if cache_dtype == torch.uint8:
# Legacy FlashMLA UE8M0 paged path. Horizontally fused:
# Q side: per-head RMSNorm (no weight) + GPT-J RoPE, zero-filling
@@ -569,9 +541,10 @@ class DeepseekV4MLA(nn.Module):
swa_metadata.block_size,
)
# FlashInfer full-cache path: contiguous [num_blocks, block_size, 512]
# cache (no Q padding). bf16 rewrites q in place; per-tensor fp8 writes a
# separately-allocated fp8 q and quantizes the KV row.
# FlashInfer full-cache path: the [num_blocks, block_size, 512] cache
# stores the KV row in its plain dtype (no Q padding). bf16 rewrites q
# in place; per-tensor fp8 writes a separately-allocated fp8 q and
# quantizes the KV row.
block_size = swa_metadata.block_size
swa_kv_cache_3d = swa_kv_cache.view(-1, block_size, self.head_dim)
if cache_dtype == torch.bfloat16:
@@ -597,99 +570,13 @@ class DeepseekV4MLA(nn.Module):
swa_metadata.slot_mapping,
positions,
cos_sin_cache,
self.mla_attn._flashinfer_fp8_kv_scale,
self.mla_attn._flashinfer_fp8_q_scale_inv,
self._flashinfer_fp8_kv_scale,
self._flashinfer_fp8_q_scale_inv,
self.eps,
block_size,
)
return q_fp8
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
def __init__(
    self,
    num_heads: int,
    head_dim: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    compress_ratio: int,
    window_size: int,
    head_bytes: int,
    swa_cache_layer: DeepseekV4SWACache,
    attn_sink: torch.Tensor,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    # Sparse MLA Args
    indexer: object | None = None,
    topk_indices_buffer: torch.Tensor | None = None,
    aux_stream: torch.cuda.Stream | None = None,
    **extra_impl_args,
) -> None:
    """Set up the inner sparse-MLA attention layer.

    Selects the impl class (and through it the backend class) from the
    current vLLM config, stores the shape / configuration arguments on the
    instance, resolves the kv-cache dtype, lets the impl register its own
    buffers, and finally registers the layer under ``prefix`` in the
    compilation config's ``static_forward_context``.

    ``quant_config`` and ``extra_impl_args`` are accepted but not read
    anywhere in this method.

    Raises:
        ValueError: if ``prefix`` is non-empty and already present in
            ``static_forward_context``.
    """
    super().__init__()
    vllm_config = get_current_vllm_config()

    # The impl class decides both the backend and the Q-head padding below.
    self.impl_cls = _select_v4_sparse_impl(vllm_config)
    self.backend_cls = self.impl_cls.backend_cls

    self.num_heads = num_heads
    # Always a single KV head, independent of ``num_heads``.
    self.num_kv_heads = 1
    self.head_dim = head_dim
    self.scale = scale
    self.window_size = window_size
    self.head_bytes = head_bytes
    self.compress_ratio = compress_ratio
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.nope_head_dim = qk_nope_head_dim
    self.rope_head_dim = qk_rope_head_dim
    self.indexer = indexer
    self.topk_indices_buffer = topk_indices_buffer
    self.prefix = prefix  # Alias for compatibility with compressor
    self.aux_stream = aux_stream
    # Two CUDA events created up front; they are not used within this method.
    self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    # Padded Q head count is dictated by the selected impl.
    self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads)

    # Store attention sink
    assert attn_sink is not None
    self.attn_sink: torch.Tensor = attn_sink

    # Store SWA cache
    assert swa_cache_layer is not None
    self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer

    # Get vllm config for cache setup
    self.max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens
    )
    self.max_model_len = vllm_config.model_config.max_model_len

    # Resolve the kv-cache dtype from the selected backend. FlashMLA uses
    # the legacy UE8M0 paged uint8 (fp8_ds_mla) layout; FlashInfer uses a
    # contiguous bf16 / per-tensor fp8 row.
    backend = _resolve_dsv4_backend(vllm_config)
    # Without a cache config the requested cache dtype string defaults to "fp8".
    kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
    self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
        backend, kv_cache_dtype, cache_config
    )

    # Per-impl layer buffers (e.g. FlashInfer FP8 scale buffers). No-op for
    # the FlashMLA / ROCm impls.
    # Called after the dtype resolution above, so ``kv_cache_torch_dtype`` is
    # already set when the impl inspects the layer.
    self.impl_cls.init_layer_buffers(self)

    # Register with compilation context for metadata lookup
    compilation_config = vllm_config.compilation_config
    if prefix and prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    if prefix:
        compilation_config.static_forward_context[prefix] = self

    # Empty placeholder tensor assigned at construction time.
    # NOTE(review): presumably replaced with the real cache later; the
    # assignment site is not visible here.
    self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class stored on this layer as ``backend_cls``."""
    return self.backend_cls
@@ -698,8 +585,9 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
self.compress_ratio <= 1
): # SWA part. Allocated separately as DeepseekV4SWACache.
return None
# FlashMLA uses the UE8M0 paged uint8 layout (576B aligned); FlashInfer
# uses a contiguous bf16 / per-tensor fp8 cache with no extra alignment.
# FlashMLA uses the fp8_ds_mla block format (UE8M0 block-scaled fp8 as
# uint8, 576B aligned); FlashInfer stores a plain bf16 / per-tensor fp8
# row with no extra alignment.
is_flashmla = self.kv_cache_dtype == "fp8_ds_mla"
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
@@ -712,15 +600,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
model_version="deepseek_v4",
)
def forward(
    self,
    q: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Delegate attention to the selected impl's ``forward_mqa``.

    The layer itself is passed as the impl's first argument, followed by
    ``q``, ``kv``, ``positions`` and ``output`` unchanged. Nothing is
    returned from this method.
    """
    self.impl_cls.forward_mqa(self, q, kv, positions, output)
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(
@@ -3,28 +3,30 @@
"""DeepSeek V4 FlashInfer TRTLLM-gen sparse MLA backend.
Uses FlashInfer's public ``trtllm_batch_decode_sparse_mla_dsv4`` launcher with a
contiguous bf16 / per-tensor FP8 KV cache. Shares the V4 sparse-index pipeline
(SWA cache + compressor + indexer, 256-token blocks, head_size 512) with the
FlashMLA V4 backend; only the attention forward differs.
plain bf16 / per-tensor FP8 KV row (vs FlashMLA's packed ``fp8_ds_mla`` block
format). Shares the V4 sparse-index pipeline (SWA cache + compressor + indexer,
256-token blocks, head_size 512) with the FlashMLA V4 backend; only the
attention forward differs.
"""
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, ClassVar, cast
import torch
from vllm.forward_context import get_forward_context
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.common.ops import (
build_flashinfer_mixed_sparse_indices,
)
from vllm.models.deepseek_v4.nvidia.flashmla import (
DeepseekV4FlashMLASparseBackend,
DeepseekV4SparseMLAAttentionImpl,
from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLASparseBackend
from vllm.models.deepseek_v4.nvidia.ops.o_proj import (
compute_fp8_einsum_recipe,
deep_gemm_fp8_o_proj,
)
from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseMetadata
if TYPE_CHECKING:
from vllm.models.deepseek_v4.attention import DeepseekV4MLAAttention
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
# 128 MB TRTLLM-gen workspace, allocated once per device and zero-initialized
@@ -51,23 +53,21 @@ class DeepseekV4FlashInferMLASparseBackend(DeepseekV4FlashMLASparseBackend):
Inheriting from the FlashMLA V4 backend reuses its ``FlashMLASparseMetadata``
builder (which the V4 sparse-index pipeline needs — the V3.2 FlashInfer
builder lacks the ``c128a_*`` fields), 256-token blocks, head_size 512, and
the contiguous (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla``
dtypes.
the (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla`` dtypes.
"""
@staticmethod
def get_name() -> str:
    """Return the name string identifying this backend."""
    return "FLASHINFER_MLA_SPARSE_DSV4"
@staticmethod
def get_impl_cls() -> type["DeepseekV4FlashInferMLASparseImpl"]:
    """Return the attention impl class paired with this backend."""
    return DeepseekV4FlashInferMLASparseImpl
class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
"""FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4."""
class DeepseekV4FlashInferMLAAttention(DeepseekV4Attention):
"""FlashInfer TRTLLM-gen sparse MLA attention layer for DeepSeek V4."""
backend_cls = DeepseekV4FlashInferMLASparseBackend
# FlashInfer stores a plain bf16 / per-tensor fp8 KV row, not the FlashMLA
# packed fp8_ds_mla block format (UE8M0 block-scaled fp8 as uint8).
use_flashmla_fp8_layout: ClassVar[bool] = False
@classmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
@@ -79,27 +79,44 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
)
return 64 if num_heads <= 64 else 128
@classmethod
def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Apply the output projection by delegating to ``deep_gemm_fp8_o_proj``.

    Forwards the attention output ``o`` and ``positions`` together with this
    layer's RoPE cos/sin cache, the ``wo_a`` / ``wo_b`` projection modules,
    the group / head-dimension sizes, and the FP8 einsum settings stored on
    the instance. Returns whatever the helper returns.
    """
    # Positional operands, gathered in the same order the helper receives them.
    cos_sin_cache = self.rotary_emb.cos_sin_cache
    wo_a = self.wo_a
    wo_b = self.wo_b

    # Group layout: the local heads are split evenly across the local groups.
    local_groups = self.n_local_groups
    group_width = self.n_local_heads // self.n_local_groups

    projected = deep_gemm_fp8_o_proj(
        o,
        positions,
        cos_sin_cache,
        wo_a,
        wo_b,
        n_groups=local_groups,
        heads_per_group=group_width,
        nope_dim=self.nope_head_dim,
        rope_dim=self.rope_head_dim,
        o_lora_rank=self.o_lora_rank,
        einsum_recipe=self._einsum_recipe,
        tma_aligned_scales=self._tma_aligned_scales,
    )
    return projected
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe()
# Per-tensor FP8 scale buffers + precomputed scalar BMM scales. Only the
# per-tensor FP8 cache path consumes these; bf16 reads ``layer.scale``.
if layer.kv_cache_torch_dtype != torch.float8_e4m3fn:
# per-tensor FP8 cache path consumes these; bf16 reads ``self.scale``.
if self.kv_cache_torch_dtype != torch.float8_e4m3fn:
return
# TODO: load real per-tensor Q/KV scales from the checkpoint; unit
# scales until the scale tensor names are wired.
fp8_q_scale = 1.0
fp8_kv_scale = 1.0
layer.register_buffer(
self.register_buffer(
"_flashinfer_fp8_q_scale",
torch.tensor([fp8_q_scale], dtype=torch.float32),
persistent=False,
)
layer.register_buffer(
self.register_buffer(
"_flashinfer_fp8_q_scale_inv",
torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32),
persistent=False,
)
layer.register_buffer(
self.register_buffer(
"_flashinfer_fp8_kv_scale",
torch.tensor([fp8_kv_scale], dtype=torch.float32),
persistent=False,
@@ -107,13 +124,11 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# TRTLLM-gen takes scalar scale args on a distinct (correct) C++ path
# vs 1-elem tensors, so these are Python floats. bmm1 folds the softmax
# scale and the Q/KV per-tensor scales; bmm2 is the KV scale.
layer._flashinfer_fp8_bmm1_scale = layer.scale * fp8_q_scale * fp8_kv_scale
layer._flashinfer_fp8_bmm2_scale = fp8_kv_scale
self._flashinfer_fp8_bmm1_scale = self.scale * fp8_q_scale * fp8_kv_scale
self._flashinfer_fp8_bmm2_scale = fp8_kv_scale
@classmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
def forward_mqa(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -147,21 +162,20 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert isinstance(attn_metadata, dict)
flashmla_metadata = cast(
FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
)
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(layer.swa_cache_layer.prefix),
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_only = layer.compress_ratio <= 1
swa_only = self.compress_ratio <= 1
# SWA-only layers don't allocate their own compressed KV cache.
self_kv_cache = layer.kv_cache if not swa_only else None
swa_kv_cache = layer.swa_cache_layer.kv_cache
self_kv_cache = self.kv_cache if not swa_only else None
swa_kv_cache = self.swa_cache_layer.kv_cache
cls._forward(
layer=layer,
self._forward(
q=q,
kv_cache=self_kv_cache,
swa_k_cache=swa_kv_cache,
@@ -171,10 +185,8 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
output=output,
)
@classmethod
def _build_sparse_index_metadata(
cls,
layer: "DeepseekV4MLAAttention",
self,
kv_cache: torch.Tensor | None,
swa_k_cache: torch.Tensor,
swa_metadata: "DeepseekSparseSWAMetadata",
@@ -200,17 +212,17 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert swa_metadata.block_table is not None
decode_swa_indices = swa_metadata.decode_swa_indices.reshape(
num_decode_tokens, layer.window_size
num_decode_tokens, self.window_size
)
decode_compressed_topk_lens = None
decode_compressed_indices_are_local = False
decode_is_valid_token = None
if swa_only:
assert layer.topk_indices_buffer is not None
assert self.topk_indices_buffer is not None
compressed_kv_cache = swa_k_cache
decode_compressed_indices = None
prefill_topk_indices = layer.topk_indices_buffer[
prefill_topk_indices = self.topk_indices_buffer[
num_decode_tokens:num_tokens, :0
]
compressed_block_table = None
@@ -221,24 +233,24 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert attn_metadata is not None
compressed_kv_cache = kv_cache
compressed_block_table = attn_metadata.block_table[:num_reqs]
compressed_block_size = attn_metadata.block_size // layer.compress_ratio
compressed_block_size = attn_metadata.block_size // self.compress_ratio
if layer.compress_ratio == 4:
assert layer.topk_indices_buffer is not None
if self.compress_ratio == 4:
assert self.topk_indices_buffer is not None
if num_prefill_tokens > 0:
prefill_topk_indices = layer.topk_indices_buffer[
prefill_topk_indices = self.topk_indices_buffer[
num_decode_tokens:num_tokens
]
top_k = prefill_topk_indices.shape[-1]
else:
prefill_topk_indices = layer.topk_indices_buffer[:0, :0]
prefill_topk_indices = self.topk_indices_buffer[:0, :0]
top_k = 0
decode_compressed_indices_are_local = True
assert swa_metadata.is_valid_token is not None
decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens]
if num_decode_tokens > 0:
decode_compressed_indices = layer.topk_indices_buffer[
decode_compressed_indices = self.topk_indices_buffer[
:num_decode_tokens
]
else:
@@ -284,18 +296,16 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_metadata.block_size,
compressed_block_table,
compressed_block_size,
layer.window_size,
layer.compress_ratio,
self.window_size,
self.compress_ratio,
top_k,
decode_compressed_indices_are_local=decode_compressed_indices_are_local,
decode_is_valid_token=decode_is_valid_token,
)
return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens
@classmethod
def _forward(
cls,
layer: "DeepseekV4MLAAttention",
self,
q: torch.Tensor,
kv_cache: torch.Tensor | None,
swa_k_cache: torch.Tensor,
@@ -304,7 +314,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_only: bool,
output: torch.Tensor,
) -> None:
assert layer.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)
assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)
num_decodes = swa_metadata.num_decodes
num_prefills = swa_metadata.num_prefills
num_decode_tokens = swa_metadata.num_decode_tokens
@@ -319,8 +329,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
seq_lens,
sparse_indices,
sparse_topk_lens,
) = cls._build_sparse_index_metadata(
layer=layer,
) = self._build_sparse_index_metadata(
kv_cache=kv_cache,
swa_k_cache=swa_k_cache,
swa_metadata=swa_metadata,
@@ -332,12 +341,12 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# restrict to the real tokens (the launcher validates sparse indices).
query = q[:num_tokens]
output = output[:num_tokens]
bmm1_scale: float | torch.Tensor = layer.scale
bmm1_scale: float | torch.Tensor = self.scale
bmm2_scale: float | torch.Tensor = 1.0
if layer.kv_cache_torch_dtype == torch.float8_e4m3fn:
if self.kv_cache_torch_dtype == torch.float8_e4m3fn:
assert query.dtype == torch.float8_e4m3fn
bmm1_scale = layer._flashinfer_fp8_bmm1_scale
bmm2_scale = layer._flashinfer_fp8_bmm2_scale
bmm1_scale = self._flashinfer_fp8_bmm1_scale
bmm2_scale = self._flashinfer_fp8_bmm2_scale
else:
assert query.dtype == torch.bfloat16
query = query.contiguous()
@@ -376,7 +385,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
out=output[:num_decode_tokens],
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
sinks=layer.attn_sink,
sinks=self.attn_sink,
cum_seq_lens_q=decode_cu,
max_q_len=int(decode_lens_cpu.max().item()),
)
@@ -401,7 +410,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
out=output[num_decode_tokens:num_tokens],
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
sinks=layer.attn_sink,
sinks=self.attn_sink,
cum_seq_lens_q=prefill_cu,
max_q_len=int(prefill_lens_cpu.max().item()),
)
+72 -118
View File
@@ -1,22 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from typing import TYPE_CHECKING, ClassVar, cast
from typing import TYPE_CHECKING, cast
import torch
from vllm.forward_context import get_forward_context
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.common.ops import (
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
)
from vllm.v1.attention.backend import (
AttentionBackend,
MultipleOf,
SparseMLAAttentionImpl,
from vllm.models.deepseek_v4.nvidia.ops.o_proj import (
compute_fp8_einsum_recipe,
deep_gemm_fp8_o_proj,
)
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
FlashMLASparseMetadata,
@@ -28,63 +28,9 @@ from vllm.v1.attention.ops.flashmla import (
from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.models.deepseek_v4.attention import (
DeepseekV4MLAAttention,
)
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
    """Abstract parent for DeepseekV4 sparse MLA impls.

    V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``)
    rather than the v1 framework, so ``forward_mqa`` is overridden with a
    classmethod that takes the layer as its first argument. This Liskov-broken
    override is intentional: the grandparent's instance-method ``forward_mqa``
    is never called on V4 layers.
    """

    # Attention backend class paired with the impl. Only declared here (no
    # value assigned), so each concrete impl is expected to provide it.
    backend_cls: ClassVar[type[AttentionBackend]]

    # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
    # workspace allocated in _forward_prefill and is also read by the V4 layer's
    # dummy-run path to pre-reserve that workspace.
    PREFILL_CHUNK_SIZE: ClassVar[int] = 4

    @classmethod
    @abstractmethod
    def forward_mqa(  # type: ignore[override]
        cls,
        layer: "DeepseekV4MLAAttention",
        q: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Run sparse MLA attention on behalf of ``layer``.

        Abstract: this base version only raises ``NotImplementedError``.
        Concrete impls receive the layer explicitly (instead of ``self``)
        along with ``q``, ``kv``, ``positions`` and the ``output`` tensor,
        and return ``None``.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_padded_num_q_heads(cls, num_heads: int) -> int:
        """Q head count the backend wants q allocated at.

        The MLA wrapper allocates the q/output buffers at
        ``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must
        satisfy ``result >= num_heads``. Backends with no padding constraint
        return ``num_heads``.
        """
        raise NotImplementedError

    @classmethod
    def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
        """Register impl-specific buffers on the layer at construction.

        No-op by default; FlashInfer overrides this to register its per-tensor
        FP8 scale buffers.
        """
        return None
class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -94,10 +40,6 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
def get_name() -> str:
    """Return the name string identifying this backend."""
    return "FLASHMLA_SPARSE_DSV4"
@staticmethod
def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
    """Return the attention impl class paired with this backend."""
    return DeepseekV4FlashMLASparseImpl
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# DeepSeek V4 layout: 448 NoPE + 64 RoPE = 512 (overrides the
@@ -120,11 +62,31 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
return (num_blocks, block_size, head_size)
class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
"""FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer."""
class DeepseekV4FlashMLAAttention(DeepseekV4Attention):
"""FlashMLA sparse MLA attention layer for DeepSeek V4 (CUDA)."""
backend_cls = DeepseekV4FlashMLASparseBackend
def __init__(self, *args, **kwargs) -> None:
    """Build the base attention layer, then cache the FP8 einsum settings.

    All arguments are forwarded untouched to the parent constructor. The
    two values produced by ``compute_fp8_einsum_recipe()`` are stored as
    ``_einsum_recipe`` and ``_tma_aligned_scales`` for later use by
    ``_o_proj``.
    """
    super().__init__(*args, **kwargs)
    recipe_and_alignment = compute_fp8_einsum_recipe()
    self._einsum_recipe, self._tma_aligned_scales = recipe_and_alignment
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Apply the output projection by delegating to ``deep_gemm_fp8_o_proj``.

    Forwards the attention output ``o`` and ``positions`` together with this
    layer's RoPE cos/sin cache, the ``wo_a`` / ``wo_b`` projection modules,
    the group / head-dimension sizes, and the FP8 einsum settings cached in
    ``__init__``. Returns whatever the helper returns.
    """
    # Positional operands, gathered in the same order the helper receives them.
    cos_sin_cache = self.rotary_emb.cos_sin_cache
    wo_a = self.wo_a
    wo_b = self.wo_b

    # Group layout: the local heads are split evenly across the local groups.
    local_groups = self.n_local_groups
    group_width = self.n_local_heads // self.n_local_groups

    projected = deep_gemm_fp8_o_proj(
        o,
        positions,
        cos_sin_cache,
        wo_a,
        wo_b,
        n_groups=local_groups,
        heads_per_group=group_width,
        nope_dim=self.nope_head_dim,
        rope_dim=self.rope_head_dim,
        o_lora_rank=self.o_lora_rank,
        einsum_recipe=self._einsum_recipe,
        tma_aligned_scales=self._tma_aligned_scales,
    )
    return projected
@classmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
# FP8 decode kernel only supports h_q = 64 or 128.
@@ -135,10 +97,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
)
return 64 if num_heads <= 64 else 128
@classmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
def forward_mqa(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -159,35 +119,35 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# Warmup dummy run: no real metadata. Reserve the same bf16
# gather workspace _forward_prefill would; the dequantize / topk
# / sparse_fwd kernels are skipped this step.
swa_only = layer.compress_ratio <= 1
swa_only = self.compress_ratio <= 1
N = (
0
if swa_only
else (layer.max_model_len + layer.compress_ratio - 1)
// layer.compress_ratio
else (self.max_model_len + self.compress_ratio - 1)
// self.compress_ratio
)
M = N + layer.window_size + layer.max_num_batched_tokens
M = N + self.window_size + self.max_num_batched_tokens
current_workspace_manager().get_simultaneous(
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
output.zero_()
return
assert isinstance(attn_metadata, dict)
flashmla_metadata = cast(
FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
)
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(layer.swa_cache_layer.prefix),
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_only = layer.compress_ratio <= 1
swa_only = self.compress_ratio <= 1
# SWA-only layers (compress_ratio <= 1) don't have their own KV cache
# allocation, so layer.kv_cache may be empty after profiling cleanup.
self_kv_cache = layer.kv_cache if not swa_only else None
swa_kv_cache = layer.swa_cache_layer.kv_cache
# allocation, so self.kv_cache may be empty after profiling cleanup.
self_kv_cache = self.kv_cache if not swa_only else None
swa_kv_cache = self.swa_cache_layer.kv_cache
# Split prefill and decode
num_decodes = swa_metadata.num_decodes
@@ -195,8 +155,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
num_decode_tokens = swa_metadata.num_decode_tokens
if num_prefills > 0:
cls._forward_prefill(
layer=layer,
self._forward_prefill(
q=q[num_decode_tokens:],
positions=positions[num_decode_tokens:],
compressed_k_cache=self_kv_cache,
@@ -206,8 +165,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_metadata=swa_metadata,
)
if num_decodes > 0:
cls._forward_decode(
layer=layer,
self._forward_decode(
q=q[:num_decode_tokens],
kv_cache=self_kv_cache,
swa_metadata=swa_metadata,
@@ -216,10 +174,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
output=output[:num_decode_tokens],
)
@classmethod
def _forward_decode(
cls,
layer: "DeepseekV4MLAAttention",
self,
q: torch.Tensor,
kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1
swa_metadata: "DeepseekSparseSWAMetadata",
@@ -235,13 +191,13 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
if not swa_only:
assert attn_metadata is not None
assert swa_metadata.is_valid_token is not None
block_size = attn_metadata.block_size // layer.compress_ratio
block_size = attn_metadata.block_size // self.compress_ratio
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
if layer.compress_ratio == 4:
if self.compress_ratio == 4:
# C4A: local indices differ per layer (filled by Indexer).
assert layer.topk_indices_buffer is not None
assert self.topk_indices_buffer is not None
global_indices, topk_lens = compute_global_topk_indices_and_lens(
layer.topk_indices_buffer[:num_decode_tokens],
self.topk_indices_buffer[:num_decode_tokens],
swa_metadata.token_to_req_indices,
attn_metadata.block_table[:num_decodes],
block_size,
@@ -258,12 +214,12 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# We treat queries in the same seq as different queries
# and later we only attend by generated indices.
# q arrives pre-padded to layer.padded_heads by the outer wrapper.
# q arrives pre-padded to self.padded_heads by the outer wrapper.
q = q.unsqueeze(1)
# Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
# Use unsqueeze to preserve strides (handles padded blocks correctly)
swa_cache = layer.swa_cache_layer.kv_cache.unsqueeze(-2)
swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
# Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
if kv_cache is not None:
kv_cache = kv_cache.unsqueeze(-2)
@@ -274,20 +230,20 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# and num_splits via PyTorch's graph-aware allocator so CUDA graph
# capture reuses the same addresses on replay); subsequent same-type
# layers see have_initialized=True and skip the planner.
if layer.compress_ratio <= 1:
if self.compress_ratio <= 1:
tile_metadata = swa_metadata.tile_sched_swaonly
elif layer.compress_ratio == 4:
elif self.compress_ratio == 4:
tile_metadata = swa_metadata.tile_sched_c4a
elif layer.compress_ratio == 128:
elif self.compress_ratio == 128:
tile_metadata = swa_metadata.tile_sched_c128a
else:
raise ValueError(
f"Unsupported compress_ratio={layer.compress_ratio}; "
f"Unsupported compress_ratio={self.compress_ratio}; "
"expected 1, 4, or 128."
)
assert tile_metadata is not None, (
"swa_metadata missing tile_sched entry for "
f"compress_ratio={layer.compress_ratio}; "
f"compress_ratio={self.compress_ratio}; "
"DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
"allocate one for this layer type."
)
@@ -302,18 +258,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
is_fp8_kvcache=True,
indices=swa_indices,
topk_length=swa_lens,
softmax_scale=layer.scale,
attn_sink=layer.attn_sink,
softmax_scale=self.scale,
attn_sink=self.attn_sink,
extra_k_cache=kv_cache if not swa_only else None,
extra_indices_in_kvcache=topk_indices,
extra_topk_length=topk_lens,
out=output.unsqueeze(1),
)
@classmethod
def _forward_prefill(
cls,
layer: "DeepseekV4MLAAttention",
self,
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1
@@ -343,9 +297,9 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
prefill_token_base = query_start_loc_cpu[num_decodes]
if not swa_only:
if layer.compress_ratio == 4:
assert layer.topk_indices_buffer is not None
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
if self.compress_ratio == 4:
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
topk_indices = topk_indices[:num_prefill_tokens]
else:
# C128A: pre-computed during metadata build.
@@ -355,16 +309,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# Compressed region must fit the full compressed pool (seq_len //
# compress_ratio), not just top_k. top_k bounds how many indices
# the indexer selects, not the pool size it indexes into.
N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio
N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
else:
# NOTE(woosuk): topk_indices will not be used for SWA-only layers.
assert layer.topk_indices_buffer is not None
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
top_k = 0
N = 0
M = N + layer.window_size + layer.max_num_batched_tokens
chunk_size_const = cls.PREFILL_CHUNK_SIZE
M = N + self.window_size + self.max_num_batched_tokens
chunk_size_const = self.PREFILL_CHUNK_SIZE
num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const
workspace_manager = current_workspace_manager()
@@ -382,10 +336,10 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
dequantize_and_gather_k_cache(
kv[:chunk_size],
compressed_k_cache,
seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio,
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
gather_lens=None,
block_table=block_table[chunk_start:chunk_end],
block_size=attn_metadata.block_size // layer.compress_ratio,
block_size=attn_metadata.block_size // self.compress_ratio,
offset=0,
)
@@ -416,8 +370,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
],
seq_lens[chunk_start:chunk_end],
gather_lens[chunk_start:chunk_end],
layer.window_size,
layer.compress_ratio,
self.window_size,
self.compress_ratio,
top_k,
M,
N,
@@ -426,8 +380,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=layer.scale,
attn_sink=layer.attn_sink,
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
+18 -164
View File
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
@@ -55,13 +54,14 @@ from vllm.model_executor.models.utils import (
maybe_prefix,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLA,
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
DeepseekV4FlashInferMLAAttention,
)
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.registry import AttentionBackendEnum
class DeepseekV4MLP(nn.Module):
@@ -713,163 +713,18 @@ class DeepseekV4MoE(nn.Module):
self.experts.finalize_weights()
class DeepseekV4Attention(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]:
"""Pick the CUDA sparse-MLA attention class for the configured backend.
An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the
FlashInfer TRTLLM-gen path; otherwise the FlashMLA path is used.
"""
if (
vllm_config.attention_config.backend
== AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
layer_id = extract_layer_index(prefix)
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.n_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
assert self.n_heads % tp_size == 0
self.n_local_heads = self.n_heads // tp_size
self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
self.head_dim = config.head_dim
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.n_groups = config.o_groups
self.n_local_groups = self.n_groups // tp_size
self.window_size = config.sliding_window
# NOTE(zyongye) Compress ratio can't be 0
# we do this for because MTP layer is not included
# in the compress ratio list
if layer_id < config.num_hidden_layers:
self.compress_ratio = max(1, config.compress_ratios[layer_id])
else:
self.compress_ratio = 1
self.eps = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
# Padded to min 64 heads for FlashMLA, initialized to -inf
# (no sink effect). Weight loading fills the first n_local_heads slots.
padded_heads = max(self.n_local_heads, 64)
self.attn_sink = nn.Parameter(
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
self.fused_wqa_wkv = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_wqa_wkv",
disable_tp=True, # fused ReplicatedLinear
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wq_b",
)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
self.wo_a = ColumnParallelLinear(
self.n_heads * self.head_dim // self.n_groups,
self.n_groups * self.o_lora_rank,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_a",
)
self.wo_a.is_bmm = True
self.wo_a.bmm_batch_size = self.n_local_groups
self.wo_b = RowParallelLinear(
self.n_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_b",
)
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = config.quantization_config["scale_fmt"]
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
self.rotary_emb = build_deepseek_v4_rope(
config,
head_dim=self.head_dim,
rope_head_dim=self.rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
compress_ratio=self.compress_ratio,
)
self.indexer = None
if self.compress_ratio == 4:
# Only C4A uses sparse attention and hence has indexer.
# aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is
# free here (outer GEMMs joined) for the inner overlap of
# wq_b+fused_indexer_q_rope_quant vs compressor.
indexer_aux_stream = (
aux_stream_list[2] if aux_stream_list is not None else None
)
self.indexer = DeepseekV4Indexer(
vllm_config,
config=config,
hidden_size=self.hidden_size,
q_lora_rank=self.q_lora_rank,
quant_config=quant_config,
cache_config=vllm_config.cache_config,
topk_indices_buffer=topk_indices_buffer,
compress_ratio=self.compress_ratio,
prefix=f"{prefix}.indexer",
aux_stream=indexer_aux_stream,
)
self.mla_attn = DeepseekV4MLA(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
wq_b=self.wq_b,
kv_norm=self.kv_norm,
wo_a=self.wo_a,
wo_b=self.wo_b,
attn_sink=self.attn_sink,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
prefix=prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
):
return self.mla_attn(positions, hidden_states, llama_4_scaling)
return DeepseekV4FlashInferMLAAttention
return DeepseekV4FlashMLAAttention
class DeepseekV4DecoderLayer(nn.Module):
@@ -886,7 +741,7 @@ class DeepseekV4DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
self.rms_norm_eps = config.rms_norm_eps
self.attn = DeepseekV4Attention(
self.attn = _select_dsv4_attn_cls(vllm_config)(
vllm_config,
prefix=f"{prefix}.attn",
topk_indices_buffer=topk_indices_buffer,
@@ -1043,7 +898,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MLA.attn_gemm_parallel_execute
# DeepseekV4Attention.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
@@ -1341,7 +1196,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
".attn.compressor.": ".attn.mla_attn.compressor.",
".shared_experts.w2": ".shared_experts.down_proj",
},
)
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.models.deepseek_v4.common.ops import fused_inv_rope_fp8_quant
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_einsum
def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]:
    """Choose the fp8_einsum recipe and scale layout for the current device.

    The decision is keyed on the major compute capability reported by
    ``current_platform``:

    * major <= 9 (e.g. SM90): FP32 block scales stay [g, r/128, d/128],
      i.e. sfb_gran_mn=128 -> recipe ``(1, 128, 128)``, scales not
      TMA-aligned.
    * major >= 10 (e.g. SM100): INT32 packed scales become [g, r, ...],
      i.e. sfb_gran_mn=1 -> recipe ``(1, 1, 128)``, scales TMA-aligned.

    Returns:
        ``(einsum_recipe, tma_aligned_scales)``, intended to be forwarded to
        ``deep_gemm_fp8_o_proj``.
    """
    capability = current_platform.get_device_capability()
    assert capability is not None, "DeepseekV4 attention requires a CUDA device"
    # Both outputs flip together at the SM100 boundary, so branch once.
    if capability.major >= 10:
        return (1, 1, 128), True
    return (1, 128, 128), False
def deep_gemm_fp8_o_proj(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    wo_a: nn.Module,
    wo_b: nn.Module,
    *,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int,
    rope_dim: int,
    o_lora_rank: int,
    einsum_recipe: tuple[int, int, int],
    tma_aligned_scales: bool,
) -> torch.Tensor:
    """Run the output projection: inverse RoPE + FP8 quant, einsum, then wo_b.

    Shared by the FlashMLA and FlashInfer CUDA backends.

    Args:
        o: Attention output; its leading dimension sizes the result buffer.
        positions: Forwarded unchanged to ``fused_inv_rope_fp8_quant``.
        cos_sin_cache: Forwarded unchanged to ``fused_inv_rope_fp8_quant``.
        wo_a: Module whose ``weight`` and ``weight_scale_inv`` form the
            second einsum operand.
        wo_b: Module applied to the flattened einsum result.
        n_groups: Number of output groups (middle dim of the einsum result).
        heads_per_group: Forwarded to the fused quant kernel.
        nope_dim: Forwarded to the fused quant kernel.
        rope_dim: Forwarded to the fused quant kernel.
        o_lora_rank: Last dim of the einsum result.
        einsum_recipe: ``recipe`` argument for ``fp8_einsum``; obtained from
            ``compute_fp8_einsum_recipe``.
        tma_aligned_scales: Scale-layout flag for the fused quant kernel;
            obtained from ``compute_fp8_einsum_recipe``.

    Returns:
        The value returned by ``wo_b`` for the ``(o.shape[0], -1)`` flattened
        intermediate.
    """
    # Step 1: undo RoPE and quantize the attention output to FP8 in one pass.
    quant_o, quant_o_scale = fused_inv_rope_fp8_quant(
        o,
        positions,
        cos_sin_cache,
        n_groups=n_groups,
        heads_per_group=heads_per_group,
        nope_dim=nope_dim,
        rope_dim=rope_dim,
        tma_aligned_scales=tma_aligned_scales,
    )

    # Step 2: grouped wo_a projection, written into a bf16 buffer.
    num_tokens = o.shape[0]
    wo_a_operand = (wo_a.weight, wo_a.weight_scale_inv)
    low_rank_out = torch.empty(
        (num_tokens, n_groups, o_lora_rank),
        device=o.device,
        dtype=torch.bfloat16,
    )
    fp8_einsum(
        "bhr,hdr->bhd",
        (quant_o, quant_o_scale),
        wo_a_operand,
        low_rank_out,
        recipe=einsum_recipe,
    )

    # Step 3: merge the group and rank dims, then apply wo_b.
    return wo_b(low_rank_out.flatten(1))
@@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
fp8_use_mixed_batch = (
self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
)
# DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
# DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not
# consume fp8_extra_metadata. Skipping the build here avoids a
# forced D2H sync on seq_lens that would otherwise fire on every
# prefill-bearing step, lifting GPU utilization on long-prefill