mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[DSV4] Refactor DeepseekV4Attention (#44569)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@@ -45,11 +44,7 @@ from vllm.model_executor.models.utils import (
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4Indexer,
|
||||
DeepseekV4MLA,
|
||||
)
|
||||
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
|
||||
from vllm.models.deepseek_v4.amd.rocm import DeepseekV4ROCMAiterMLAAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
@@ -225,158 +220,6 @@ class DeepseekV4MoE(nn.Module):
|
||||
return final_hidden_states.view(org_shape)
|
||||
|
||||
|
||||
class DeepseekV4Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
layer_id = extract_layer_index(prefix)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.n_heads % tp_size == 0
|
||||
|
||||
self.n_local_heads = self.n_heads // tp_size
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.n_groups = config.o_groups
|
||||
self.n_local_groups = self.n_groups // tp_size
|
||||
self.window_size = config.sliding_window
|
||||
# NOTE(zyongye) Compress ratio can't be 0
|
||||
# we do this for because MTP layer is not included
|
||||
# in the compress ratio list
|
||||
if layer_id < config.num_hidden_layers:
|
||||
self.compress_ratio = max(1, config.compress_ratios[layer_id])
|
||||
else:
|
||||
self.compress_ratio = 1
|
||||
self.eps = config.rms_norm_eps
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# Padded to min 64 heads for FlashMLA, initialized to -inf
|
||||
# (no sink effect). Weight loading fills the first n_local_heads slots.
|
||||
padded_heads = max(self.n_local_heads, 64)
|
||||
self.attn_sink = nn.Parameter(
|
||||
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.fused_wqa_wkv = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_wqa_wkv",
|
||||
disable_tp=True, # fused ReplicatedLinear
|
||||
)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_a",
|
||||
)
|
||||
self.wo_a.is_bmm = True
|
||||
self.wo_a.bmm_batch_size = self.n_local_groups
|
||||
self.wo_b = RowParallelLinear(
|
||||
self.n_groups * self.o_lora_rank,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_b",
|
||||
)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
self.scale_fmt = config.quantization_config["scale_fmt"]
|
||||
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
|
||||
self.rotary_emb = build_deepseek_v4_rope(
|
||||
config,
|
||||
head_dim=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
|
||||
self.indexer = None
|
||||
if self.compress_ratio == 4:
|
||||
# Only C4A uses sparse attention and hence has indexer.
|
||||
self.indexer = DeepseekV4Indexer(
|
||||
vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
quant_config=quant_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
compress_ratio=self.compress_ratio,
|
||||
prefix=f"{prefix}.indexer",
|
||||
)
|
||||
|
||||
self.mla_attn = DeepseekV4MLA(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
vllm_config=vllm_config,
|
||||
fused_wqa_wkv=self.fused_wqa_wkv,
|
||||
q_norm=self.q_norm,
|
||||
wq_b=self.wq_b,
|
||||
kv_norm=self.kv_norm,
|
||||
wo_a=self.wo_a,
|
||||
wo_b=self.wo_b,
|
||||
attn_sink=self.attn_sink,
|
||||
rotary_emb=self.rotary_emb,
|
||||
indexer=self.indexer,
|
||||
indexer_rotary_emb=self.rotary_emb,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
aux_stream_list=aux_stream_list,
|
||||
window_size=self.window_size,
|
||||
compress_ratio=self.compress_ratio,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None,
|
||||
):
|
||||
return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
||||
|
||||
|
||||
class DeepseekV4DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -395,7 +238,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
self.attn = DeepseekV4Attention(
|
||||
self.attn = DeepseekV4ROCMAiterMLAAttention(
|
||||
vllm_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
@@ -601,7 +444,7 @@ class DeepseekV4Model(nn.Module):
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Three aux streams: one per non-default input GEMM in
|
||||
# DeepseekV4MLA.attn_gemm_parallel_execute
|
||||
# DeepseekV4Attention.attn_gemm_parallel_execute
|
||||
# (compressor kv_score, indexer.weights_proj, indexer.compressor
|
||||
# kv_score). fused_wqa_wkv stays on the default stream.
|
||||
# Disable them on ROCm because of hang issues.
|
||||
@@ -897,7 +740,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
".shared_experts.w2": ".shared_experts.down_proj",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -2,15 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
|
||||
from vllm.models.deepseek_v4.common.ops import dequantize_and_gather_k_cache
|
||||
from vllm.models.deepseek_v4.nvidia.flashmla import (
|
||||
DeepseekV4FlashMLASparseBackend,
|
||||
DeepseekV4SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
@@ -26,16 +26,12 @@ from vllm.v1.attention.backends.mla.sparse_swa import (
|
||||
)
|
||||
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
|
||||
build_ragged_indices_from_dense,
|
||||
rocm_inv_rope_einsum,
|
||||
rocm_sparse_attn_decode,
|
||||
rocm_sparse_attn_prefill,
|
||||
)
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4MLAAttention,
|
||||
)
|
||||
|
||||
|
||||
def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
|
||||
lengths = lengths.to(dtype=torch.int32).contiguous()
|
||||
@@ -582,13 +578,9 @@ class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend):
|
||||
def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]:
|
||||
return DeepseekV4ROCMAiterMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
|
||||
return DeepseekV4ROCMAiterMLASparseImpl
|
||||
|
||||
|
||||
class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
"""ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer."""
|
||||
class DeepseekV4ROCMAiterMLAAttention(DeepseekV4Attention):
|
||||
"""ROCm sparse MLA attention layer for DeepSeek V4."""
|
||||
|
||||
backend_cls = DeepseekV4ROCMAiterMLASparseBackend
|
||||
|
||||
@@ -596,10 +588,21 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
def get_padded_num_q_heads(cls, num_heads: int) -> int:
|
||||
return num_heads
|
||||
|
||||
@classmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
# ROCm BF16 reference wo_a path (inverse RoPE + einsum) + wo_b.
|
||||
z = rocm_inv_rope_einsum(
|
||||
self.rotary_emb,
|
||||
o,
|
||||
positions,
|
||||
self.rope_head_dim,
|
||||
self.n_local_groups,
|
||||
self.o_lora_rank,
|
||||
self.wo_a,
|
||||
)
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -619,16 +622,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
# Warmup dummy run: no real metadata. Reserve the same bf16
|
||||
# gather workspace _forward_prefill would; the dequantize / topk
|
||||
# / sparse_fwd kernels are skipped this step.
|
||||
swa_only = layer.compress_ratio <= 1
|
||||
swa_only = self.compress_ratio <= 1
|
||||
N = (
|
||||
0
|
||||
if swa_only
|
||||
else (layer.max_model_len + layer.compress_ratio - 1)
|
||||
// layer.compress_ratio
|
||||
else (self.max_model_len + self.compress_ratio - 1)
|
||||
// self.compress_ratio
|
||||
)
|
||||
M = N + layer.window_size + layer.max_num_batched_tokens
|
||||
M = N + self.window_size + self.max_num_batched_tokens
|
||||
current_workspace_manager().get_simultaneous(
|
||||
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
|
||||
((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
|
||||
)
|
||||
output.zero_()
|
||||
return
|
||||
@@ -636,25 +639,24 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
rocm_metadata = cast(
|
||||
DeepseekV4ROCMAiterMLASparseMetadata | None,
|
||||
attn_metadata.get(layer.prefix),
|
||||
attn_metadata.get(self.prefix),
|
||||
)
|
||||
swa_metadata = cast(
|
||||
DeepseekV4ROCMAiterSparseSWAMetadata | None,
|
||||
attn_metadata.get(layer.swa_cache_layer.prefix),
|
||||
attn_metadata.get(self.swa_cache_layer.prefix),
|
||||
)
|
||||
assert swa_metadata is not None
|
||||
|
||||
swa_only = layer.compress_ratio <= 1
|
||||
self_kv_cache = layer.kv_cache if not swa_only else None
|
||||
swa_kv_cache = layer.swa_cache_layer.kv_cache
|
||||
swa_only = self.compress_ratio <= 1
|
||||
self_kv_cache = self.kv_cache if not swa_only else None
|
||||
swa_kv_cache = self.swa_cache_layer.kv_cache
|
||||
|
||||
num_decodes = swa_metadata.num_decodes
|
||||
num_prefills = swa_metadata.num_prefills
|
||||
num_decode_tokens = swa_metadata.num_decode_tokens
|
||||
|
||||
if num_prefills > 0:
|
||||
cls._forward_prefill(
|
||||
layer=layer,
|
||||
self._forward_prefill(
|
||||
q=q[num_decode_tokens:],
|
||||
positions=positions[num_decode_tokens:],
|
||||
compressed_k_cache=self_kv_cache,
|
||||
@@ -664,8 +666,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
swa_metadata=swa_metadata,
|
||||
)
|
||||
if num_decodes > 0:
|
||||
cls._forward_decode(
|
||||
layer=layer,
|
||||
self._forward_decode(
|
||||
q=q[:num_decode_tokens],
|
||||
kv_cache=self_kv_cache,
|
||||
swa_metadata=swa_metadata,
|
||||
@@ -674,10 +675,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
output=output[:num_decode_tokens],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _forward_decode(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor | None,
|
||||
swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata,
|
||||
@@ -695,16 +694,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
if not swa_only:
|
||||
assert attn_metadata is not None
|
||||
assert swa_metadata.is_valid_token is not None
|
||||
block_size = attn_metadata.block_size // layer.compress_ratio
|
||||
block_size = attn_metadata.block_size // self.compress_ratio
|
||||
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
|
||||
if layer.compress_ratio == 4:
|
||||
assert layer.topk_indices_buffer is not None
|
||||
if self.compress_ratio == 4:
|
||||
assert self.topk_indices_buffer is not None
|
||||
(
|
||||
topk_ragged_indices,
|
||||
topk_ragged_indptr,
|
||||
topk_lens,
|
||||
) = compute_global_topk_ragged_indices_and_indptr(
|
||||
layer.topk_indices_buffer[:num_decode_tokens],
|
||||
self.topk_indices_buffer[:num_decode_tokens],
|
||||
swa_metadata.token_to_req_indices,
|
||||
attn_metadata.block_table[:num_decodes],
|
||||
block_size,
|
||||
@@ -719,7 +718,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
rocm_sparse_attn_decode(
|
||||
q=q,
|
||||
kv_cache=kv_cache,
|
||||
swa_k_cache=layer.swa_cache_layer.kv_cache,
|
||||
swa_k_cache=self.swa_cache_layer.kv_cache,
|
||||
swa_only=swa_only,
|
||||
topk_indices=topk_indices,
|
||||
topk_lens=topk_lens,
|
||||
@@ -729,18 +728,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
swa_ragged_indptr=swa_metadata.decode_swa_ragged_indptr,
|
||||
topk_ragged_indices=topk_ragged_indices,
|
||||
topk_ragged_indptr=topk_ragged_indptr,
|
||||
attn_sink=layer.attn_sink,
|
||||
scale=layer.scale,
|
||||
head_dim=layer.head_dim,
|
||||
nope_head_dim=layer.nope_head_dim,
|
||||
rope_head_dim=layer.rope_head_dim,
|
||||
attn_sink=self.attn_sink,
|
||||
scale=self.scale,
|
||||
head_dim=self.head_dim,
|
||||
nope_head_dim=self.nope_head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
output=output,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _forward_prefill(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
compressed_k_cache: torch.Tensor | None,
|
||||
@@ -768,34 +765,34 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
prefill_token_base = query_start_loc_cpu[num_decodes]
|
||||
|
||||
if not swa_only:
|
||||
if layer.compress_ratio == 4:
|
||||
assert layer.topk_indices_buffer is not None
|
||||
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
|
||||
if self.compress_ratio == 4:
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
|
||||
topk_indices = topk_indices[:num_prefill_tokens]
|
||||
else:
|
||||
assert attn_metadata is not None
|
||||
topk_indices = attn_metadata.c128a_prefill_topk_indices
|
||||
assert topk_indices is not None
|
||||
top_k = topk_indices.shape[-1]
|
||||
N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio
|
||||
N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
|
||||
else:
|
||||
assert layer.topk_indices_buffer is not None
|
||||
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
|
||||
top_k = 0
|
||||
N = 0
|
||||
|
||||
M = N + layer.window_size + layer.max_num_batched_tokens
|
||||
num_chunks = (num_prefills + cls.PREFILL_CHUNK_SIZE - 1) // (
|
||||
cls.PREFILL_CHUNK_SIZE
|
||||
M = N + self.window_size + self.max_num_batched_tokens
|
||||
num_chunks = (num_prefills + self.PREFILL_CHUNK_SIZE - 1) // (
|
||||
self.PREFILL_CHUNK_SIZE
|
||||
)
|
||||
|
||||
workspace_manager = current_workspace_manager()
|
||||
kv = workspace_manager.get_simultaneous(
|
||||
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
|
||||
((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
|
||||
)[0]
|
||||
for chunk_idx in range(num_chunks):
|
||||
chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE
|
||||
chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills)
|
||||
chunk_start = chunk_idx * self.PREFILL_CHUNK_SIZE
|
||||
chunk_end = min(chunk_start + self.PREFILL_CHUNK_SIZE, num_prefills)
|
||||
chunk_size = chunk_end - chunk_start
|
||||
if not swa_only:
|
||||
assert attn_metadata is not None
|
||||
@@ -804,10 +801,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
dequantize_and_gather_k_cache(
|
||||
kv[:chunk_size],
|
||||
compressed_k_cache,
|
||||
seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio,
|
||||
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
|
||||
gather_lens=None,
|
||||
block_table=block_table[chunk_start:chunk_end],
|
||||
block_size=attn_metadata.block_size // layer.compress_ratio,
|
||||
block_size=attn_metadata.block_size // self.compress_ratio,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
@@ -836,8 +833,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
],
|
||||
seq_lens[chunk_start:chunk_end],
|
||||
gather_lens[chunk_start:chunk_end],
|
||||
layer.window_size,
|
||||
layer.compress_ratio,
|
||||
self.window_size,
|
||||
self.compress_ratio,
|
||||
top_k,
|
||||
M,
|
||||
N,
|
||||
@@ -847,10 +844,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
kv=kv.view(-1, 1, q.shape[-1]),
|
||||
indices=combined_indices,
|
||||
topk_length=combined_lens,
|
||||
scale=layer.scale,
|
||||
head_dim=layer.head_dim,
|
||||
nope_head_dim=layer.nope_head_dim,
|
||||
rope_head_dim=layer.rope_head_dim,
|
||||
attn_sink=layer.attn_sink,
|
||||
scale=self.scale,
|
||||
head_dim=self.head_dim,
|
||||
nope_head_dim=self.nope_head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
attn_sink=self.attn_sink,
|
||||
output=output[query_start:query_end],
|
||||
)
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
DeepseekV4 MLA Attention Layer
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -15,16 +16,16 @@ from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
|
||||
from vllm.models.deepseek_v4.common.ops import (
|
||||
fused_indexer_q_rope_quant,
|
||||
fused_inv_rope_fp8_quant,
|
||||
fused_q_kv_rmsnorm,
|
||||
)
|
||||
from vllm.utils.deep_gemm import fp8_einsum
|
||||
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import (
|
||||
@@ -42,14 +43,9 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
|
||||
from vllm.models.deepseek_v4.compressor import DeepseekCompressor
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.multi_stream_utils import (
|
||||
execute_in_parallel,
|
||||
maybe_execute_in_parallel,
|
||||
@@ -62,187 +58,209 @@ from vllm.v1.attention.backends.mla.indexer import (
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.models.deepseek_v4.nvidia.flashmla import (
|
||||
DeepseekV4SparseMLAAttentionImpl,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _resolve_dsv4_backend(vllm_config: VllmConfig | None):
|
||||
"""Return the explicitly-requested DSv4 sparse backend enum, or None."""
|
||||
if vllm_config is None:
|
||||
return None
|
||||
attn_config = getattr(vllm_config, "attention_config", None)
|
||||
return getattr(attn_config, "backend", None) if attn_config is not None else None
|
||||
|
||||
|
||||
def _select_v4_sparse_impl(
|
||||
vllm_config: VllmConfig | None = None,
|
||||
) -> "type[DeepseekV4SparseMLAAttentionImpl]":
|
||||
"""Pick the V4 sparse MLA impl class.
|
||||
|
||||
An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the
|
||||
FlashInfer TRTLLM-gen path; otherwise the platform default (FlashMLA on
|
||||
NVIDIA, ROCm Aiter on AMD) is used.
|
||||
"""
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
backend = _resolve_dsv4_backend(vllm_config)
|
||||
if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
|
||||
from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
|
||||
DeepseekV4FlashInferMLASparseImpl,
|
||||
)
|
||||
|
||||
logger.info_once("Using FLASHINFER_MLA_SPARSE_DSV4 backend.")
|
||||
return DeepseekV4FlashInferMLASparseImpl
|
||||
if current_platform.is_rocm():
|
||||
from vllm.models.deepseek_v4.amd.rocm import (
|
||||
DeepseekV4ROCMAiterMLASparseImpl,
|
||||
)
|
||||
|
||||
logger.info_once("Using ROCM_FLASHMLA_SPARSE_DSV4 backend.")
|
||||
return DeepseekV4ROCMAiterMLASparseImpl
|
||||
from vllm.models.deepseek_v4.nvidia.flashmla import (
|
||||
DeepseekV4FlashMLASparseImpl,
|
||||
)
|
||||
|
||||
logger.info_once("Using FLASHMLA_SPARSE_DSV4 backend.")
|
||||
return DeepseekV4FlashMLASparseImpl
|
||||
|
||||
|
||||
def _resolve_dsv4_kv_cache_dtype(
|
||||
backend,
|
||||
use_flashmla_fp8_layout: bool,
|
||||
kv_cache_dtype: str,
|
||||
cache_config: CacheConfig | None,
|
||||
) -> tuple[str, torch.dtype]:
|
||||
"""Map ``(backend, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.
|
||||
"""Map ``(layout, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.
|
||||
|
||||
FlashInfer V4 reads a contiguous 512-wide KV row (bf16 or per-tensor FP8
|
||||
E4M3); FlashMLA V4 reads the legacy UE8M0 paged layout (uint8 /
|
||||
``fp8_ds_mla``). For FlashMLA the canonical ``fp8_ds_mla`` string is
|
||||
written back onto ``cache_config`` so the page-size specs pick the 576B
|
||||
layout.
|
||||
Both layouts are paged; they differ in the per-token block format. The
|
||||
FlashMLA fp8 layout (FlashMLA / ROCm Aiter) is the ``fp8_ds_mla`` format:
|
||||
UE8M0 block-scaled fp8 packed as ``uint8`` (the canonical ``fp8_ds_mla``
|
||||
string is written back onto ``cache_config`` so the page-size specs pick
|
||||
the 576B per-token slot). Otherwise (FlashInfer) each token's KV row is
|
||||
stored in its plain element dtype — bf16 or per-tensor FP8 E4M3.
|
||||
"""
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
if use_flashmla_fp8_layout:
|
||||
# fp8_ds_mla block format: UE8M0 block-scaled fp8 packed as uint8.
|
||||
assert kv_cache_dtype.startswith("fp8"), (
|
||||
f"DeepseekV4 FlashMLA fp8 layout only supports fp8 kv-cache, "
|
||||
f"got {kv_cache_dtype}"
|
||||
)
|
||||
if kv_cache_dtype != "fp8_ds_mla":
|
||||
if cache_config is not None:
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
|
||||
return kv_cache_dtype, torch.uint8
|
||||
|
||||
if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
return kv_cache_dtype, torch.float8_e4m3fn
|
||||
# auto / bfloat16 -> contiguous BF16 cache.
|
||||
return kv_cache_dtype, torch.bfloat16
|
||||
|
||||
# FlashMLA (and ROCm Aiter): legacy UE8M0 paged uint8 cache.
|
||||
assert kv_cache_dtype.startswith("fp8"), (
|
||||
f"DeepseekV4 FlashMLA sparse backend only supports fp8 kv-cache, "
|
||||
f"got {kv_cache_dtype}"
|
||||
)
|
||||
if kv_cache_dtype != "fp8_ds_mla":
|
||||
if cache_config is not None:
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
|
||||
return kv_cache_dtype, torch.uint8
|
||||
# Plain bf16 / per-tensor fp8 KV row (FlashInfer).
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
return kv_cache_dtype, torch.float8_e4m3fn
|
||||
# auto / bfloat16 -> plain bf16 KV row.
|
||||
return kv_cache_dtype, torch.bfloat16
|
||||
|
||||
|
||||
class DeepseekV4MLA(nn.Module):
|
||||
class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
|
||||
"""DeepseekV4 MLA attention layer.
|
||||
|
||||
The platform-specific sparse-MLA forward (``forward_mqa`` /
|
||||
``get_padded_num_q_heads`` / ``_o_proj`` / ``backend_cls``) is provided by a
|
||||
subclass — ``DeepseekV4FlashMLAAttention`` / ``DeepseekV4FlashInferMLAAttention``
|
||||
(CUDA) or ``DeepseekV4ROCMAiterMLAAttention`` (ROCm) — selected by the
|
||||
platform-specific deepseek_v4 model module. The base is never instantiated
|
||||
directly.
|
||||
"""
|
||||
|
||||
# Provided by the platform subclass.
|
||||
backend_cls: ClassVar[type[AttentionBackend]]
|
||||
# KV-cache per-token block format (both layouts are paged). True (default)
|
||||
# = FlashMLA / ROCm fp8_ds_mla (UE8M0 block-scaled fp8 packed as uint8);
|
||||
# False = FlashInfer plain bf16 / per-tensor fp8 KV row.
|
||||
use_flashmla_fp8_layout: ClassVar[bool] = True
|
||||
# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
|
||||
# workspace allocated in _forward_prefill and is also read by the dummy-run
|
||||
# path to pre-reserve that workspace.
|
||||
PREFILL_CHUNK_SIZE: ClassVar[int] = 4
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_padded_num_q_heads(cls, num_heads: int) -> int:
|
||||
"""Q head count the q/output buffers are allocated at.
|
||||
|
||||
The layer allocates the q/output buffers at
|
||||
``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must satisfy
|
||||
``result >= num_heads``. Backends with no padding constraint return
|
||||
``num_heads``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
"""Platform-specific sparse MLA forward; writes attention into ``output``."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse-RoPE + wo_a + wo_b output projection (platform-specific)."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
o_lora_rank: int | None,
|
||||
vllm_config: VllmConfig,
|
||||
fused_wqa_wkv: torch.nn.Module,
|
||||
q_norm: torch.nn.Module,
|
||||
wq_b: torch.nn.Module,
|
||||
kv_norm: torch.nn.Module,
|
||||
wo_a: torch.nn.Module,
|
||||
wo_b: torch.nn.Module,
|
||||
attn_sink: torch.nn.Module,
|
||||
rotary_emb: torch.nn.Module,
|
||||
indexer: torch.nn.Module | None,
|
||||
indexer_rotary_emb: torch.nn.Module,
|
||||
topk_indices_buffer: torch.Tensor | None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None,
|
||||
window_size: int,
|
||||
compress_ratio: int | None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
prefix: str,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.n_local_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.window_size = window_size
|
||||
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
|
||||
self.prefix = prefix
|
||||
|
||||
# Extract config from vllm_config
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
layer_id = extract_layer_index(prefix)
|
||||
|
||||
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
|
||||
self.eps = config.rms_norm_eps
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = head_dim - self.rope_head_dim
|
||||
self.n_local_groups = config.o_groups // tp_size
|
||||
self.prefix = prefix # Alias for compatibility with compressor
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_heads = config.num_attention_heads
|
||||
assert self.n_heads % tp_size == 0
|
||||
self.n_local_heads = self.n_heads // tp_size
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.n_groups = config.o_groups
|
||||
self.n_local_groups = self.n_groups // tp_size
|
||||
self.window_size = config.sliding_window
|
||||
# NOTE(zyongye) Compress ratio can't be 0
|
||||
# we do this for because MTP layer is not included
|
||||
# in the compress ratio list
|
||||
if layer_id < config.num_hidden_layers:
|
||||
self.compress_ratio = max(1, config.compress_ratios[layer_id])
|
||||
else:
|
||||
self.compress_ratio = 1
|
||||
self.eps = config.rms_norm_eps
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
# Store projection modules
|
||||
self.fused_wqa_wkv = fused_wqa_wkv
|
||||
self.q_norm = q_norm
|
||||
self.wq_b = wq_b
|
||||
|
||||
self.kv_norm = kv_norm
|
||||
self.wo_a = wo_a
|
||||
|
||||
self._wo_a_act_quant = QuantFP8(
|
||||
static=False,
|
||||
group_shape=GroupShape(1, 128),
|
||||
use_ue8m0=True,
|
||||
# Padded Q head count is dictated by the platform subclass.
|
||||
self.padded_heads = self.get_padded_num_q_heads(self.n_local_heads)
|
||||
# Sink padded to the same head count, initialized to -inf (no sink
|
||||
# effect). Weight loading fills the first n_local_heads slots.
|
||||
self.attn_sink = nn.Parameter(
|
||||
torch.full((self.padded_heads,), -float("inf"), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
|
||||
# INT32) so fp8_einsum can handle layout transform internally.
|
||||
self._wo_a_act_quant.use_deep_gemm_supported = False
|
||||
self.wo_b = wo_b
|
||||
|
||||
# Pick fp8_einsum recipe based on GPU arch:
|
||||
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
|
||||
# SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
|
||||
cap = current_platform.get_device_capability()
|
||||
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
|
||||
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
|
||||
self._tma_aligned_scales = cap.major >= 10
|
||||
self.fused_wqa_wkv = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_wqa_wkv",
|
||||
disable_tp=True, # fused ReplicatedLinear
|
||||
)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.indexer_rotary_emb = indexer_rotary_emb
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_a",
|
||||
)
|
||||
self.wo_a.is_bmm = True
|
||||
self.wo_a.bmm_batch_size = self.n_local_groups
|
||||
self.wo_b = RowParallelLinear(
|
||||
self.n_groups * self.o_lora_rank,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_b",
|
||||
)
|
||||
|
||||
# Initialize rotary embedding before the indexer/compressor consume it.
|
||||
self.rotary_emb = build_deepseek_v4_rope(
|
||||
config,
|
||||
head_dim=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
self.indexer_rotary_emb = self.rotary_emb
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
|
||||
self.indexer = indexer
|
||||
|
||||
# Per-head RMS normalization for Q (no learnable weights)
|
||||
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
|
||||
|
||||
# TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
|
||||
head_bytes = (
|
||||
self.nope_head_dim # 448 fp8 NoPE
|
||||
+ self.rope_head_dim * 2 # 64 bf16 RoPE
|
||||
+ self.nope_head_dim // 64 # 7B scale factors
|
||||
+ 1 # 1B pad
|
||||
)
|
||||
self.indexer = None
|
||||
if self.compress_ratio == 4:
|
||||
# Only C4A uses sparse attention and hence has indexer.
|
||||
# aux_stream_list[2] is free here (outer GEMMs joined) for the inner
|
||||
# overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on
|
||||
# ROCm, where aux_stream_list is None.
|
||||
indexer_aux_stream = (
|
||||
aux_stream_list[2] if aux_stream_list is not None else None
|
||||
)
|
||||
self.indexer = DeepseekV4Indexer(
|
||||
vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
compress_ratio=self.compress_ratio,
|
||||
prefix=f"{prefix}.indexer",
|
||||
aux_stream=indexer_aux_stream,
|
||||
)
|
||||
|
||||
# Will be None on ROCm for now.
|
||||
self.aux_stream_list = aux_stream_list
|
||||
@@ -252,45 +270,39 @@ class DeepseekV4MLA(nn.Module):
|
||||
self.ln_events = [torch.cuda.Event() for _ in range(4)]
|
||||
|
||||
assert cache_config is not None, "DeepseekV4 attention requires cache_config"
|
||||
# Resolve the SWA cache tensor dtype from the selected backend: FlashMLA
|
||||
# uses the legacy UE8M0 paged uint8 layout; FlashInfer uses a contiguous
|
||||
# bf16 / per-tensor fp8 row.
|
||||
backend = _resolve_dsv4_backend(vllm_config)
|
||||
_, swa_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
|
||||
backend, cache_config.cache_dtype, cache_config
|
||||
# ---- Attention / KV-cache setup ----
|
||||
self.max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
# Resolve the kv-cache dtype from this backend's block format (a
|
||||
# ClassVar set by the subclass): fp8_ds_mla (UE8M0 block-scaled fp8 as
|
||||
# uint8) for FlashMLA / ROCm, vs a plain bf16 / per-tensor fp8 row for
|
||||
# FlashInfer. The same resolution drives the SWA cache tensor dtype
|
||||
# below.
|
||||
self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
|
||||
self.use_flashmla_fp8_layout, cache_config.cache_dtype, cache_config
|
||||
)
|
||||
|
||||
self.swa_cache_layer = DeepseekV4SWACache(
|
||||
head_dim=self.head_dim,
|
||||
window_size=self.window_size,
|
||||
dtype=swa_cache_torch_dtype,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
prefix=f"{prefix}.swa_cache",
|
||||
cache_config=cache_config,
|
||||
)
|
||||
|
||||
self.mla_attn = DeepseekV4MLAAttention(
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
compress_ratio=self.compress_ratio,
|
||||
window_size=self.window_size,
|
||||
head_bytes=head_bytes,
|
||||
swa_cache_layer=self.swa_cache_layer,
|
||||
attn_sink=attn_sink, # already padded with -inf
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
indexer=self.indexer,
|
||||
topk_indices_buffer=self.topk_indices_buffer,
|
||||
)
|
||||
# Mirror the inner layer's padded head count (single source of truth).
|
||||
self.padded_heads = self.mla_attn.padded_heads
|
||||
# Register with compilation context for metadata lookup.
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix and prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
if prefix:
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.kv_cache = torch.tensor([])
|
||||
|
||||
# Create the compressor for layers with compress_ratio > 1; after
|
||||
# creating the DeepseekV4MLAAttention layer to get its cache.
|
||||
# Create the compressor for layers with compress_ratio > 1; after the
|
||||
# attention setup above so its KV-cache prefix (self.prefix) is set.
|
||||
self.compressor = None
|
||||
if self.compress_ratio > 1:
|
||||
self.compressor = DeepseekCompressor(
|
||||
@@ -300,7 +312,7 @@ class DeepseekV4MLA(nn.Module):
|
||||
head_dim=self.head_dim,
|
||||
rotate=True,
|
||||
prefix=f"{prefix}.compressor",
|
||||
k_cache_prefix=self.mla_attn.prefix,
|
||||
k_cache_prefix=self.prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -324,48 +336,8 @@ class DeepseekV4MLA(nn.Module):
|
||||
self.attention_impl(hidden_states, positions, o_padded)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# Keep ROCm on the BF16 reference wo_a path util kernel ready.
|
||||
if current_platform.is_rocm():
|
||||
z = rocm_inv_rope_einsum(
|
||||
self.rotary_emb,
|
||||
o,
|
||||
positions,
|
||||
self.rope_head_dim,
|
||||
self.n_local_groups,
|
||||
self.o_lora_rank,
|
||||
self.wo_a,
|
||||
)
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
o,
|
||||
positions,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
n_groups=self.n_local_groups,
|
||||
heads_per_group=self.n_local_heads // self.n_local_groups,
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
tma_aligned_scales=self._tma_aligned_scales,
|
||||
)
|
||||
|
||||
wo_a_fp8 = self.wo_a.weight
|
||||
wo_a_scale = self.wo_a.weight_scale_inv
|
||||
|
||||
z = torch.empty(
|
||||
(num_tokens, self.n_local_groups, self.o_lora_rank),
|
||||
device=o.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
fp8_einsum(
|
||||
"bhr,hdr->bhd",
|
||||
(o_fp8, o_scale),
|
||||
(wo_a_fp8, wo_a_scale),
|
||||
z,
|
||||
recipe=self._einsum_recipe,
|
||||
)
|
||||
|
||||
return self.wo_b(z.flatten(1))
|
||||
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
|
||||
return self._o_proj(o, positions)
|
||||
|
||||
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
|
||||
aux_streams = self.aux_stream_list
|
||||
@@ -451,7 +423,7 @@ class DeepseekV4MLA(nn.Module):
|
||||
)
|
||||
|
||||
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
|
||||
# on the default stream so q stays on its consumer stream (mla_attn
|
||||
# on the default stream so q stays on its consumer stream (forward_mqa
|
||||
# downstream reads q on default). Indexer/compressor go on aux for
|
||||
# overlap with default's GEMM + cache write.
|
||||
if self.indexer is not None:
|
||||
@@ -514,7 +486,7 @@ class DeepseekV4MLA(nn.Module):
|
||||
|
||||
# MLA attention writes into the pre-allocated `out` buffer
|
||||
# ([num_tokens, padded_heads, head_dim]).
|
||||
self.mla_attn(q, kv, positions, output=out)
|
||||
self.forward_mqa(q, kv, positions, out)
|
||||
|
||||
def _fused_qnorm_rope_kv_insert(
|
||||
self,
|
||||
@@ -549,7 +521,7 @@ class DeepseekV4MLA(nn.Module):
|
||||
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||
cache_dtype = swa_kv_cache.dtype
|
||||
|
||||
# kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
|
||||
# kv is unchanged; attention reads kv solely via swa_kv_cache.
|
||||
if cache_dtype == torch.uint8:
|
||||
# Legacy FlashMLA UE8M0 paged path. Horizontally fused:
|
||||
# Q side: per-head RMSNorm (no weight) + GPT-J RoPE, zero-filling
|
||||
@@ -569,9 +541,10 @@ class DeepseekV4MLA(nn.Module):
|
||||
swa_metadata.block_size,
|
||||
)
|
||||
|
||||
# FlashInfer full-cache path: contiguous [num_blocks, block_size, 512]
|
||||
# cache (no Q padding). bf16 rewrites q in place; per-tensor fp8 writes a
|
||||
# separately-allocated fp8 q and quantizes the KV row.
|
||||
# FlashInfer full-cache path: the [num_blocks, block_size, 512] cache
|
||||
# stores the KV row in its plain dtype (no Q padding). bf16 rewrites q
|
||||
# in place; per-tensor fp8 writes a separately-allocated fp8 q and
|
||||
# quantizes the KV row.
|
||||
block_size = swa_metadata.block_size
|
||||
swa_kv_cache_3d = swa_kv_cache.view(-1, block_size, self.head_dim)
|
||||
if cache_dtype == torch.bfloat16:
|
||||
@@ -597,99 +570,13 @@ class DeepseekV4MLA(nn.Module):
|
||||
swa_metadata.slot_mapping,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
self.mla_attn._flashinfer_fp8_kv_scale,
|
||||
self.mla_attn._flashinfer_fp8_q_scale_inv,
|
||||
self._flashinfer_fp8_kv_scale,
|
||||
self._flashinfer_fp8_q_scale_inv,
|
||||
self.eps,
|
||||
block_size,
|
||||
)
|
||||
return q_fp8
|
||||
|
||||
|
||||
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
compress_ratio: int,
|
||||
window_size: int,
|
||||
head_bytes: int,
|
||||
swa_cache_layer: DeepseekV4SWACache,
|
||||
attn_sink: torch.Tensor,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
# Sparse MLA Args
|
||||
indexer: object | None = None,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream: torch.cuda.Stream | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.impl_cls = _select_v4_sparse_impl(vllm_config)
|
||||
self.backend_cls = self.impl_cls.backend_cls
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = 1
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
self.window_size = window_size
|
||||
self.head_bytes = head_bytes
|
||||
self.compress_ratio = compress_ratio
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.nope_head_dim = qk_nope_head_dim
|
||||
self.rope_head_dim = qk_rope_head_dim
|
||||
self.indexer = indexer
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
|
||||
self.prefix = prefix # Alias for compatibility with compressor
|
||||
|
||||
self.aux_stream = aux_stream
|
||||
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
|
||||
|
||||
# Padded Q head count is dictated by the selected impl.
|
||||
self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads)
|
||||
|
||||
# Store attention sink
|
||||
assert attn_sink is not None
|
||||
self.attn_sink: torch.Tensor = attn_sink
|
||||
# Store SWA cache
|
||||
assert swa_cache_layer is not None
|
||||
self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer
|
||||
|
||||
# Get vllm config for cache setup
|
||||
self.max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
# Resolve the kv-cache dtype from the selected backend. FlashMLA uses
|
||||
# the legacy UE8M0 paged uint8 (fp8_ds_mla) layout; FlashInfer uses a
|
||||
# contiguous bf16 / per-tensor fp8 row.
|
||||
backend = _resolve_dsv4_backend(vllm_config)
|
||||
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
|
||||
self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
|
||||
backend, kv_cache_dtype, cache_config
|
||||
)
|
||||
|
||||
# Per-impl layer buffers (e.g. FlashInfer FP8 scale buffers). No-op for
|
||||
# the FlashMLA / ROCm impls.
|
||||
self.impl_cls.init_layer_buffers(self)
|
||||
|
||||
# Register with compilation context for metadata lookup
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix and prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
if prefix:
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
self.kv_cache = torch.tensor([])
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.backend_cls
|
||||
|
||||
@@ -698,8 +585,9 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.compress_ratio <= 1
|
||||
): # SWA part. Allocated separately as DeepseekV4SWACache.
|
||||
return None
|
||||
# FlashMLA uses the UE8M0 paged uint8 layout (576B aligned); FlashInfer
|
||||
# uses a contiguous bf16 / per-tensor fp8 cache with no extra alignment.
|
||||
# FlashMLA uses the fp8_ds_mla block format (UE8M0 block-scaled fp8 as
|
||||
# uint8, 576B aligned); FlashInfer stores a plain bf16 / per-tensor fp8
|
||||
# row with no extra alignment.
|
||||
is_flashmla = self.kv_cache_dtype == "fp8_ds_mla"
|
||||
return MLAAttentionSpec(
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
@@ -712,15 +600,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
model_version="deepseek_v4",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
self.impl_cls.forward_mqa(self, q, kv, positions, output)
|
||||
|
||||
|
||||
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
|
||||
@@ -3,28 +3,30 @@
|
||||
"""DeepSeek V4 FlashInfer TRTLLM-gen sparse MLA backend.
|
||||
|
||||
Uses FlashInfer's public ``trtllm_batch_decode_sparse_mla_dsv4`` launcher with a
|
||||
contiguous bf16 / per-tensor FP8 KV cache. Shares the V4 sparse-index pipeline
|
||||
(SWA cache + compressor + indexer, 256-token blocks, head_size 512) with the
|
||||
FlashMLA V4 backend; only the attention forward differs.
|
||||
plain bf16 / per-tensor FP8 KV row (vs FlashMLA's packed ``fp8_ds_mla`` block
|
||||
format). Shares the V4 sparse-index pipeline (SWA cache + compressor + indexer,
|
||||
256-token blocks, head_size 512) with the FlashMLA V4 backend; only the
|
||||
attention forward differs.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING, ClassVar, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
|
||||
from vllm.models.deepseek_v4.common.ops import (
|
||||
build_flashinfer_mixed_sparse_indices,
|
||||
)
|
||||
from vllm.models.deepseek_v4.nvidia.flashmla import (
|
||||
DeepseekV4FlashMLASparseBackend,
|
||||
DeepseekV4SparseMLAAttentionImpl,
|
||||
from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLASparseBackend
|
||||
from vllm.models.deepseek_v4.nvidia.ops.o_proj import (
|
||||
compute_fp8_einsum_recipe,
|
||||
deep_gemm_fp8_o_proj,
|
||||
)
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.models.deepseek_v4.attention import DeepseekV4MLAAttention
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
|
||||
|
||||
# 128 MB TRTLLM-gen workspace, allocated once per device and zero-initialized
|
||||
@@ -51,23 +53,21 @@ class DeepseekV4FlashInferMLASparseBackend(DeepseekV4FlashMLASparseBackend):
|
||||
Inheriting from the FlashMLA V4 backend reuses its ``FlashMLASparseMetadata``
|
||||
builder (which the V4 sparse-index pipeline needs — the V3.2 FlashInfer
|
||||
builder lacks the ``c128a_*`` fields), 256-token blocks, head_size 512, and
|
||||
the contiguous (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla``
|
||||
dtypes.
|
||||
the (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla`` dtypes.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA_SPARSE_DSV4"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["DeepseekV4FlashInferMLASparseImpl"]:
|
||||
return DeepseekV4FlashInferMLASparseImpl
|
||||
|
||||
|
||||
class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
"""FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4."""
|
||||
class DeepseekV4FlashInferMLAAttention(DeepseekV4Attention):
|
||||
"""FlashInfer TRTLLM-gen sparse MLA attention layer for DeepSeek V4."""
|
||||
|
||||
backend_cls = DeepseekV4FlashInferMLASparseBackend
|
||||
# FlashInfer stores a plain bf16 / per-tensor fp8 KV row, not the FlashMLA
|
||||
# packed fp8_ds_mla block format (UE8M0 block-scaled fp8 as uint8).
|
||||
use_flashmla_fp8_layout: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_padded_num_q_heads(cls, num_heads: int) -> int:
|
||||
@@ -79,27 +79,44 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
)
|
||||
return 64 if num_heads <= 64 else 128
|
||||
|
||||
@classmethod
|
||||
def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
|
||||
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
return deep_gemm_fp8_o_proj(
|
||||
o,
|
||||
positions,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
self.wo_a,
|
||||
self.wo_b,
|
||||
n_groups=self.n_local_groups,
|
||||
heads_per_group=self.n_local_heads // self.n_local_groups,
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
einsum_recipe=self._einsum_recipe,
|
||||
tma_aligned_scales=self._tma_aligned_scales,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe()
|
||||
# Per-tensor FP8 scale buffers + precomputed scalar BMM scales. Only the
|
||||
# per-tensor FP8 cache path consumes these; bf16 reads ``layer.scale``.
|
||||
if layer.kv_cache_torch_dtype != torch.float8_e4m3fn:
|
||||
# per-tensor FP8 cache path consumes these; bf16 reads ``self.scale``.
|
||||
if self.kv_cache_torch_dtype != torch.float8_e4m3fn:
|
||||
return
|
||||
# TODO: load real per-tensor Q/KV scales from the checkpoint; unit
|
||||
# scales until the scale tensor names are wired.
|
||||
fp8_q_scale = 1.0
|
||||
fp8_kv_scale = 1.0
|
||||
layer.register_buffer(
|
||||
self.register_buffer(
|
||||
"_flashinfer_fp8_q_scale",
|
||||
torch.tensor([fp8_q_scale], dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
layer.register_buffer(
|
||||
self.register_buffer(
|
||||
"_flashinfer_fp8_q_scale_inv",
|
||||
torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
layer.register_buffer(
|
||||
self.register_buffer(
|
||||
"_flashinfer_fp8_kv_scale",
|
||||
torch.tensor([fp8_kv_scale], dtype=torch.float32),
|
||||
persistent=False,
|
||||
@@ -107,13 +124,11 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
# TRTLLM-gen takes scalar scale args on a distinct (correct) C++ path
|
||||
# vs 1-elem tensors, so these are Python floats. bmm1 folds the softmax
|
||||
# scale and the Q/KV per-tensor scales; bmm2 is the KV scale.
|
||||
layer._flashinfer_fp8_bmm1_scale = layer.scale * fp8_q_scale * fp8_kv_scale
|
||||
layer._flashinfer_fp8_bmm2_scale = fp8_kv_scale
|
||||
self._flashinfer_fp8_bmm1_scale = self.scale * fp8_q_scale * fp8_kv_scale
|
||||
self._flashinfer_fp8_bmm2_scale = fp8_kv_scale
|
||||
|
||||
@classmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -147,21 +162,20 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
flashmla_metadata = cast(
|
||||
FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
|
||||
FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
|
||||
)
|
||||
swa_metadata = cast(
|
||||
"DeepseekSparseSWAMetadata | None",
|
||||
attn_metadata.get(layer.swa_cache_layer.prefix),
|
||||
attn_metadata.get(self.swa_cache_layer.prefix),
|
||||
)
|
||||
assert swa_metadata is not None
|
||||
|
||||
swa_only = layer.compress_ratio <= 1
|
||||
swa_only = self.compress_ratio <= 1
|
||||
# SWA-only layers don't allocate their own compressed KV cache.
|
||||
self_kv_cache = layer.kv_cache if not swa_only else None
|
||||
swa_kv_cache = layer.swa_cache_layer.kv_cache
|
||||
self_kv_cache = self.kv_cache if not swa_only else None
|
||||
swa_kv_cache = self.swa_cache_layer.kv_cache
|
||||
|
||||
cls._forward(
|
||||
layer=layer,
|
||||
self._forward(
|
||||
q=q,
|
||||
kv_cache=self_kv_cache,
|
||||
swa_k_cache=swa_kv_cache,
|
||||
@@ -171,10 +185,8 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
output=output,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _build_sparse_index_metadata(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
self,
|
||||
kv_cache: torch.Tensor | None,
|
||||
swa_k_cache: torch.Tensor,
|
||||
swa_metadata: "DeepseekSparseSWAMetadata",
|
||||
@@ -200,17 +212,17 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
assert swa_metadata.block_table is not None
|
||||
|
||||
decode_swa_indices = swa_metadata.decode_swa_indices.reshape(
|
||||
num_decode_tokens, layer.window_size
|
||||
num_decode_tokens, self.window_size
|
||||
)
|
||||
decode_compressed_topk_lens = None
|
||||
decode_compressed_indices_are_local = False
|
||||
decode_is_valid_token = None
|
||||
|
||||
if swa_only:
|
||||
assert layer.topk_indices_buffer is not None
|
||||
assert self.topk_indices_buffer is not None
|
||||
compressed_kv_cache = swa_k_cache
|
||||
decode_compressed_indices = None
|
||||
prefill_topk_indices = layer.topk_indices_buffer[
|
||||
prefill_topk_indices = self.topk_indices_buffer[
|
||||
num_decode_tokens:num_tokens, :0
|
||||
]
|
||||
compressed_block_table = None
|
||||
@@ -221,24 +233,24 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
assert attn_metadata is not None
|
||||
compressed_kv_cache = kv_cache
|
||||
compressed_block_table = attn_metadata.block_table[:num_reqs]
|
||||
compressed_block_size = attn_metadata.block_size // layer.compress_ratio
|
||||
compressed_block_size = attn_metadata.block_size // self.compress_ratio
|
||||
|
||||
if layer.compress_ratio == 4:
|
||||
assert layer.topk_indices_buffer is not None
|
||||
if self.compress_ratio == 4:
|
||||
assert self.topk_indices_buffer is not None
|
||||
if num_prefill_tokens > 0:
|
||||
prefill_topk_indices = layer.topk_indices_buffer[
|
||||
prefill_topk_indices = self.topk_indices_buffer[
|
||||
num_decode_tokens:num_tokens
|
||||
]
|
||||
top_k = prefill_topk_indices.shape[-1]
|
||||
else:
|
||||
prefill_topk_indices = layer.topk_indices_buffer[:0, :0]
|
||||
prefill_topk_indices = self.topk_indices_buffer[:0, :0]
|
||||
top_k = 0
|
||||
|
||||
decode_compressed_indices_are_local = True
|
||||
assert swa_metadata.is_valid_token is not None
|
||||
decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens]
|
||||
if num_decode_tokens > 0:
|
||||
decode_compressed_indices = layer.topk_indices_buffer[
|
||||
decode_compressed_indices = self.topk_indices_buffer[
|
||||
:num_decode_tokens
|
||||
]
|
||||
else:
|
||||
@@ -284,18 +296,16 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
swa_metadata.block_size,
|
||||
compressed_block_table,
|
||||
compressed_block_size,
|
||||
layer.window_size,
|
||||
layer.compress_ratio,
|
||||
self.window_size,
|
||||
self.compress_ratio,
|
||||
top_k,
|
||||
decode_compressed_indices_are_local=decode_compressed_indices_are_local,
|
||||
decode_is_valid_token=decode_is_valid_token,
|
||||
)
|
||||
return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens
|
||||
|
||||
@classmethod
|
||||
def _forward(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor | None,
|
||||
swa_k_cache: torch.Tensor,
|
||||
@@ -304,7 +314,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
swa_only: bool,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
assert layer.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)
|
||||
assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)
|
||||
num_decodes = swa_metadata.num_decodes
|
||||
num_prefills = swa_metadata.num_prefills
|
||||
num_decode_tokens = swa_metadata.num_decode_tokens
|
||||
@@ -319,8 +329,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
seq_lens,
|
||||
sparse_indices,
|
||||
sparse_topk_lens,
|
||||
) = cls._build_sparse_index_metadata(
|
||||
layer=layer,
|
||||
) = self._build_sparse_index_metadata(
|
||||
kv_cache=kv_cache,
|
||||
swa_k_cache=swa_k_cache,
|
||||
swa_metadata=swa_metadata,
|
||||
@@ -332,12 +341,12 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
# restrict to the real tokens (the launcher validates sparse indices).
|
||||
query = q[:num_tokens]
|
||||
output = output[:num_tokens]
|
||||
bmm1_scale: float | torch.Tensor = layer.scale
|
||||
bmm1_scale: float | torch.Tensor = self.scale
|
||||
bmm2_scale: float | torch.Tensor = 1.0
|
||||
if layer.kv_cache_torch_dtype == torch.float8_e4m3fn:
|
||||
if self.kv_cache_torch_dtype == torch.float8_e4m3fn:
|
||||
assert query.dtype == torch.float8_e4m3fn
|
||||
bmm1_scale = layer._flashinfer_fp8_bmm1_scale
|
||||
bmm2_scale = layer._flashinfer_fp8_bmm2_scale
|
||||
bmm1_scale = self._flashinfer_fp8_bmm1_scale
|
||||
bmm2_scale = self._flashinfer_fp8_bmm2_scale
|
||||
else:
|
||||
assert query.dtype == torch.bfloat16
|
||||
query = query.contiguous()
|
||||
@@ -376,7 +385,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
out=output[:num_decode_tokens],
|
||||
bmm1_scale=bmm1_scale,
|
||||
bmm2_scale=bmm2_scale,
|
||||
sinks=layer.attn_sink,
|
||||
sinks=self.attn_sink,
|
||||
cum_seq_lens_q=decode_cu,
|
||||
max_q_len=int(decode_lens_cpu.max().item()),
|
||||
)
|
||||
@@ -401,7 +410,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
out=output[num_decode_tokens:num_tokens],
|
||||
bmm1_scale=bmm1_scale,
|
||||
bmm2_scale=bmm2_scale,
|
||||
sinks=layer.attn_sink,
|
||||
sinks=self.attn_sink,
|
||||
cum_seq_lens_q=prefill_cu,
|
||||
max_q_len=int(prefill_lens_cpu.max().item()),
|
||||
)
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, ClassVar, cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
|
||||
from vllm.models.deepseek_v4.common.ops import (
|
||||
combine_topk_swa_indices,
|
||||
compute_global_topk_indices_and_lens,
|
||||
dequantize_and_gather_k_cache,
|
||||
)
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
from vllm.models.deepseek_v4.nvidia.ops.o_proj import (
|
||||
compute_fp8_einsum_recipe,
|
||||
deep_gemm_fp8_o_proj,
|
||||
)
|
||||
from vllm.v1.attention.backend import MultipleOf
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
FlashMLASparseMetadata,
|
||||
@@ -28,63 +28,9 @@ from vllm.v1.attention.ops.flashmla import (
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4MLAAttention,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
|
||||
|
||||
|
||||
class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
"""Abstract parent for DeepseekV4 sparse MLA impls.
|
||||
|
||||
V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``)
|
||||
rather than the v1 framework, so ``forward_mqa`` is overridden with a
|
||||
classmethod that takes the layer as its first argument. This Liskov-broken
|
||||
override is intentional: the grandparent's instance-method ``forward_mqa``
|
||||
is never called on V4 layers.
|
||||
"""
|
||||
|
||||
backend_cls: ClassVar[type[AttentionBackend]]
|
||||
|
||||
# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
|
||||
# workspace allocated in _forward_prefill and is also read by the V4 layer's
|
||||
# dummy-run path to pre-reserve that workspace.
|
||||
PREFILL_CHUNK_SIZE: ClassVar[int] = 4
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_padded_num_q_heads(cls, num_heads: int) -> int:
|
||||
"""Q head count the backend wants q allocated at.
|
||||
|
||||
The MLA wrapper allocates the q/output buffers at
|
||||
``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must
|
||||
satisfy ``result >= num_heads``. Backends with no padding constraint
|
||||
return ``num_heads``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
|
||||
"""Register impl-specific buffers on the layer at construction.
|
||||
|
||||
No-op by default; FlashInfer overrides this to register its per-tensor
|
||||
FP8 scale buffers.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
@@ -94,10 +40,6 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_SPARSE_DSV4"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
|
||||
return DeepseekV4FlashMLASparseImpl
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
# DeepSeek V4 layout: 448 NoPE + 64 RoPE = 512 (overrides the
|
||||
@@ -120,11 +62,31 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
|
||||
class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
"""FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer."""
|
||||
class DeepseekV4FlashMLAAttention(DeepseekV4Attention):
|
||||
"""FlashMLA sparse MLA attention layer for DeepSeek V4 (CUDA)."""
|
||||
|
||||
backend_cls = DeepseekV4FlashMLASparseBackend
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe()
|
||||
|
||||
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
return deep_gemm_fp8_o_proj(
|
||||
o,
|
||||
positions,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
self.wo_a,
|
||||
self.wo_b,
|
||||
n_groups=self.n_local_groups,
|
||||
heads_per_group=self.n_local_heads // self.n_local_groups,
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
einsum_recipe=self._einsum_recipe,
|
||||
tma_aligned_scales=self._tma_aligned_scales,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_padded_num_q_heads(cls, num_heads: int) -> int:
|
||||
# FP8 decode kernel only supports h_q = 64 or 128.
|
||||
@@ -135,10 +97,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
)
|
||||
return 64 if num_heads <= 64 else 128
|
||||
|
||||
@classmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -159,35 +119,35 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
# Warmup dummy run: no real metadata. Reserve the same bf16
|
||||
# gather workspace _forward_prefill would; the dequantize / topk
|
||||
# / sparse_fwd kernels are skipped this step.
|
||||
swa_only = layer.compress_ratio <= 1
|
||||
swa_only = self.compress_ratio <= 1
|
||||
N = (
|
||||
0
|
||||
if swa_only
|
||||
else (layer.max_model_len + layer.compress_ratio - 1)
|
||||
// layer.compress_ratio
|
||||
else (self.max_model_len + self.compress_ratio - 1)
|
||||
// self.compress_ratio
|
||||
)
|
||||
M = N + layer.window_size + layer.max_num_batched_tokens
|
||||
M = N + self.window_size + self.max_num_batched_tokens
|
||||
current_workspace_manager().get_simultaneous(
|
||||
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
|
||||
((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
|
||||
)
|
||||
output.zero_()
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
flashmla_metadata = cast(
|
||||
FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
|
||||
FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
|
||||
)
|
||||
swa_metadata = cast(
|
||||
"DeepseekSparseSWAMetadata | None",
|
||||
attn_metadata.get(layer.swa_cache_layer.prefix),
|
||||
attn_metadata.get(self.swa_cache_layer.prefix),
|
||||
)
|
||||
assert swa_metadata is not None
|
||||
|
||||
swa_only = layer.compress_ratio <= 1
|
||||
swa_only = self.compress_ratio <= 1
|
||||
# SWA-only layers (compress_ratio <= 1) don't have their own KV cache
|
||||
# allocation, so layer.kv_cache may be empty after profiling cleanup.
|
||||
self_kv_cache = layer.kv_cache if not swa_only else None
|
||||
swa_kv_cache = layer.swa_cache_layer.kv_cache
|
||||
# allocation, so self.kv_cache may be empty after profiling cleanup.
|
||||
self_kv_cache = self.kv_cache if not swa_only else None
|
||||
swa_kv_cache = self.swa_cache_layer.kv_cache
|
||||
|
||||
# Split prefill and decode
|
||||
num_decodes = swa_metadata.num_decodes
|
||||
@@ -195,8 +155,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
num_decode_tokens = swa_metadata.num_decode_tokens
|
||||
|
||||
if num_prefills > 0:
|
||||
cls._forward_prefill(
|
||||
layer=layer,
|
||||
self._forward_prefill(
|
||||
q=q[num_decode_tokens:],
|
||||
positions=positions[num_decode_tokens:],
|
||||
compressed_k_cache=self_kv_cache,
|
||||
@@ -206,8 +165,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
swa_metadata=swa_metadata,
|
||||
)
|
||||
if num_decodes > 0:
|
||||
cls._forward_decode(
|
||||
layer=layer,
|
||||
self._forward_decode(
|
||||
q=q[:num_decode_tokens],
|
||||
kv_cache=self_kv_cache,
|
||||
swa_metadata=swa_metadata,
|
||||
@@ -216,10 +174,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
output=output[:num_decode_tokens],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _forward_decode(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1
|
||||
swa_metadata: "DeepseekSparseSWAMetadata",
|
||||
@@ -235,13 +191,13 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
if not swa_only:
|
||||
assert attn_metadata is not None
|
||||
assert swa_metadata.is_valid_token is not None
|
||||
block_size = attn_metadata.block_size // layer.compress_ratio
|
||||
block_size = attn_metadata.block_size // self.compress_ratio
|
||||
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
|
||||
if layer.compress_ratio == 4:
|
||||
if self.compress_ratio == 4:
|
||||
# C4A: local indices differ per layer (filled by Indexer).
|
||||
assert layer.topk_indices_buffer is not None
|
||||
assert self.topk_indices_buffer is not None
|
||||
global_indices, topk_lens = compute_global_topk_indices_and_lens(
|
||||
layer.topk_indices_buffer[:num_decode_tokens],
|
||||
self.topk_indices_buffer[:num_decode_tokens],
|
||||
swa_metadata.token_to_req_indices,
|
||||
attn_metadata.block_table[:num_decodes],
|
||||
block_size,
|
||||
@@ -258,12 +214,12 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
|
||||
# We treat queries in the same seq as different queries
|
||||
# and later we only attend by generated indices.
|
||||
# q arrives pre-padded to layer.padded_heads by the outer wrapper.
|
||||
# q arrives pre-padded to self.padded_heads by the outer wrapper.
|
||||
q = q.unsqueeze(1)
|
||||
|
||||
# Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
|
||||
# Use unsqueeze to preserve strides (handles padded blocks correctly)
|
||||
swa_cache = layer.swa_cache_layer.kv_cache.unsqueeze(-2)
|
||||
swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
|
||||
# Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
|
||||
if kv_cache is not None:
|
||||
kv_cache = kv_cache.unsqueeze(-2)
|
||||
@@ -274,20 +230,20 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
# and num_splits via PyTorch's graph-aware allocator so CUDA graph
|
||||
# capture reuses the same addresses on replay); subsequent same-type
|
||||
# layers see have_initialized=True and skip the planner.
|
||||
if layer.compress_ratio <= 1:
|
||||
if self.compress_ratio <= 1:
|
||||
tile_metadata = swa_metadata.tile_sched_swaonly
|
||||
elif layer.compress_ratio == 4:
|
||||
elif self.compress_ratio == 4:
|
||||
tile_metadata = swa_metadata.tile_sched_c4a
|
||||
elif layer.compress_ratio == 128:
|
||||
elif self.compress_ratio == 128:
|
||||
tile_metadata = swa_metadata.tile_sched_c128a
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported compress_ratio={layer.compress_ratio}; "
|
||||
f"Unsupported compress_ratio={self.compress_ratio}; "
|
||||
"expected 1, 4, or 128."
|
||||
)
|
||||
assert tile_metadata is not None, (
|
||||
"swa_metadata missing tile_sched entry for "
|
||||
f"compress_ratio={layer.compress_ratio}; "
|
||||
f"compress_ratio={self.compress_ratio}; "
|
||||
"DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
|
||||
"allocate one for this layer type."
|
||||
)
|
||||
@@ -302,18 +258,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
is_fp8_kvcache=True,
|
||||
indices=swa_indices,
|
||||
topk_length=swa_lens,
|
||||
softmax_scale=layer.scale,
|
||||
attn_sink=layer.attn_sink,
|
||||
softmax_scale=self.scale,
|
||||
attn_sink=self.attn_sink,
|
||||
extra_k_cache=kv_cache if not swa_only else None,
|
||||
extra_indices_in_kvcache=topk_indices,
|
||||
extra_topk_length=topk_lens,
|
||||
out=output.unsqueeze(1),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _forward_prefill(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1
|
||||
@@ -343,9 +297,9 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
prefill_token_base = query_start_loc_cpu[num_decodes]
|
||||
|
||||
if not swa_only:
|
||||
if layer.compress_ratio == 4:
|
||||
assert layer.topk_indices_buffer is not None
|
||||
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
|
||||
if self.compress_ratio == 4:
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
|
||||
topk_indices = topk_indices[:num_prefill_tokens]
|
||||
else:
|
||||
# C128A: pre-computed during metadata build.
|
||||
@@ -355,16 +309,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
# Compressed region must fit the full compressed pool (seq_len //
|
||||
# compress_ratio), not just top_k. top_k bounds how many indices
|
||||
# the indexer selects, not the pool size it indexes into.
|
||||
N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio
|
||||
N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
|
||||
else:
|
||||
# NOTE(woosuk): topk_indices will not be used for SWA-only layers.
|
||||
assert layer.topk_indices_buffer is not None
|
||||
topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
|
||||
top_k = 0
|
||||
N = 0
|
||||
|
||||
M = N + layer.window_size + layer.max_num_batched_tokens
|
||||
chunk_size_const = cls.PREFILL_CHUNK_SIZE
|
||||
M = N + self.window_size + self.max_num_batched_tokens
|
||||
chunk_size_const = self.PREFILL_CHUNK_SIZE
|
||||
num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const
|
||||
|
||||
workspace_manager = current_workspace_manager()
|
||||
@@ -382,10 +336,10 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
dequantize_and_gather_k_cache(
|
||||
kv[:chunk_size],
|
||||
compressed_k_cache,
|
||||
seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio,
|
||||
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
|
||||
gather_lens=None,
|
||||
block_table=block_table[chunk_start:chunk_end],
|
||||
block_size=attn_metadata.block_size // layer.compress_ratio,
|
||||
block_size=attn_metadata.block_size // self.compress_ratio,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
@@ -416,8 +370,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
],
|
||||
seq_lens[chunk_start:chunk_end],
|
||||
gather_lens[chunk_start:chunk_end],
|
||||
layer.window_size,
|
||||
layer.compress_ratio,
|
||||
self.window_size,
|
||||
self.compress_ratio,
|
||||
top_k,
|
||||
M,
|
||||
N,
|
||||
@@ -426,8 +380,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
q=q[query_start:query_end],
|
||||
kv=kv.view(-1, 1, q.shape[-1]),
|
||||
indices=combined_indices.unsqueeze(1),
|
||||
sm_scale=layer.scale,
|
||||
attn_sink=layer.attn_sink,
|
||||
sm_scale=self.scale,
|
||||
attn_sink=self.attn_sink,
|
||||
topk_length=combined_lens,
|
||||
out=output[query_start:query_end],
|
||||
)
|
||||
|
||||
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
||||
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@@ -55,13 +54,14 @@ from vllm.model_executor.models.utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4Indexer,
|
||||
DeepseekV4MLA,
|
||||
from vllm.models.deepseek_v4.attention import DeepseekV4Attention
|
||||
from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
|
||||
DeepseekV4FlashInferMLAAttention,
|
||||
)
|
||||
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
|
||||
from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention
|
||||
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
class DeepseekV4MLP(nn.Module):
|
||||
@@ -713,163 +713,18 @@ class DeepseekV4MoE(nn.Module):
|
||||
self.experts.finalize_weights()
|
||||
|
||||
|
||||
class DeepseekV4Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
||||
def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]:
|
||||
"""Pick the CUDA sparse-MLA attention class for the configured backend.
|
||||
|
||||
An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the
|
||||
FlashInfer TRTLLM-gen path; otherwise the FlashMLA path is used.
|
||||
"""
|
||||
if (
|
||||
vllm_config.attention_config.backend
|
||||
== AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
layer_id = extract_layer_index(prefix)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.n_heads % tp_size == 0
|
||||
|
||||
self.n_local_heads = self.n_heads // tp_size
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.n_groups = config.o_groups
|
||||
self.n_local_groups = self.n_groups // tp_size
|
||||
self.window_size = config.sliding_window
|
||||
# NOTE(zyongye) Compress ratio can't be 0
|
||||
# we do this for because MTP layer is not included
|
||||
# in the compress ratio list
|
||||
if layer_id < config.num_hidden_layers:
|
||||
self.compress_ratio = max(1, config.compress_ratios[layer_id])
|
||||
else:
|
||||
self.compress_ratio = 1
|
||||
self.eps = config.rms_norm_eps
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# Padded to min 64 heads for FlashMLA, initialized to -inf
|
||||
# (no sink effect). Weight loading fills the first n_local_heads slots.
|
||||
padded_heads = max(self.n_local_heads, 64)
|
||||
self.attn_sink = nn.Parameter(
|
||||
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.fused_wqa_wkv = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_wqa_wkv",
|
||||
disable_tp=True, # fused ReplicatedLinear
|
||||
)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_a",
|
||||
)
|
||||
self.wo_a.is_bmm = True
|
||||
self.wo_a.bmm_batch_size = self.n_local_groups
|
||||
self.wo_b = RowParallelLinear(
|
||||
self.n_groups * self.o_lora_rank,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_b",
|
||||
)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
self.scale_fmt = config.quantization_config["scale_fmt"]
|
||||
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
|
||||
self.rotary_emb = build_deepseek_v4_rope(
|
||||
config,
|
||||
head_dim=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
|
||||
self.indexer = None
|
||||
if self.compress_ratio == 4:
|
||||
# Only C4A uses sparse attention and hence has indexer.
|
||||
# aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is
|
||||
# free here (outer GEMMs joined) for the inner overlap of
|
||||
# wq_b+fused_indexer_q_rope_quant vs compressor.
|
||||
indexer_aux_stream = (
|
||||
aux_stream_list[2] if aux_stream_list is not None else None
|
||||
)
|
||||
self.indexer = DeepseekV4Indexer(
|
||||
vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
quant_config=quant_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
compress_ratio=self.compress_ratio,
|
||||
prefix=f"{prefix}.indexer",
|
||||
aux_stream=indexer_aux_stream,
|
||||
)
|
||||
|
||||
self.mla_attn = DeepseekV4MLA(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
vllm_config=vllm_config,
|
||||
fused_wqa_wkv=self.fused_wqa_wkv,
|
||||
q_norm=self.q_norm,
|
||||
wq_b=self.wq_b,
|
||||
kv_norm=self.kv_norm,
|
||||
wo_a=self.wo_a,
|
||||
wo_b=self.wo_b,
|
||||
attn_sink=self.attn_sink,
|
||||
rotary_emb=self.rotary_emb,
|
||||
indexer=self.indexer,
|
||||
indexer_rotary_emb=self.rotary_emb,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
aux_stream_list=aux_stream_list,
|
||||
window_size=self.window_size,
|
||||
compress_ratio=self.compress_ratio,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None,
|
||||
):
|
||||
return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
||||
return DeepseekV4FlashInferMLAAttention
|
||||
return DeepseekV4FlashMLAAttention
|
||||
|
||||
|
||||
class DeepseekV4DecoderLayer(nn.Module):
|
||||
@@ -886,7 +741,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
self.attn = DeepseekV4Attention(
|
||||
self.attn = _select_dsv4_attn_cls(vllm_config)(
|
||||
vllm_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
@@ -1043,7 +898,7 @@ class DeepseekV4Model(nn.Module):
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Three aux streams: one per non-default input GEMM in
|
||||
# DeepseekV4MLA.attn_gemm_parallel_execute
|
||||
# DeepseekV4Attention.attn_gemm_parallel_execute
|
||||
# (compressor kv_score, indexer.weights_proj, indexer.compressor
|
||||
# kv_score). fused_wqa_wkv stays on the default stream.
|
||||
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
|
||||
@@ -1341,7 +1196,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
".shared_experts.w2": ".shared_experts.down_proj",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.models.deepseek_v4.common.ops import fused_inv_rope_fp8_quant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import fp8_einsum
|
||||
|
||||
|
||||
def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]:
|
||||
"""fp8_einsum recipe + scale layout for the current GPU arch.
|
||||
|
||||
SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128.
|
||||
SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1.
|
||||
|
||||
Returns ``(einsum_recipe, tma_aligned_scales)`` for ``deep_gemm_fp8_o_proj``.
|
||||
"""
|
||||
cap = current_platform.get_device_capability()
|
||||
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
|
||||
einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
|
||||
tma_aligned_scales = cap.major >= 10
|
||||
return einsum_recipe, tma_aligned_scales
|
||||
|
||||
|
||||
def deep_gemm_fp8_o_proj(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
wo_a: nn.Module,
|
||||
wo_b: nn.Module,
|
||||
*,
|
||||
n_groups: int,
|
||||
heads_per_group: int,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
o_lora_rank: int,
|
||||
einsum_recipe: tuple[int, int, int],
|
||||
tma_aligned_scales: bool,
|
||||
) -> torch.Tensor:
|
||||
"""O projection: inverse RoPE + FP8 quant + einsum + wo_b.
|
||||
|
||||
Shared by the FlashMLA and FlashInfer CUDA backends. ``einsum_recipe`` /
|
||||
``tma_aligned_scales`` come from ``compute_fp8_einsum_recipe``.
|
||||
"""
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
o,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups=n_groups,
|
||||
heads_per_group=heads_per_group,
|
||||
nope_dim=nope_dim,
|
||||
rope_dim=rope_dim,
|
||||
tma_aligned_scales=tma_aligned_scales,
|
||||
)
|
||||
z = torch.empty(
|
||||
(o.shape[0], n_groups, o_lora_rank),
|
||||
device=o.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
fp8_einsum(
|
||||
"bhr,hdr->bhd",
|
||||
(o_fp8, o_scale),
|
||||
(wo_a.weight, wo_a.weight_scale_inv),
|
||||
z,
|
||||
recipe=einsum_recipe,
|
||||
)
|
||||
return wo_b(z.flatten(1))
|
||||
@@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
fp8_use_mixed_batch = (
|
||||
self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
|
||||
)
|
||||
# DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
|
||||
# DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not
|
||||
# consume fp8_extra_metadata. Skipping the build here avoids a
|
||||
# forced D2H sync on seq_lens that would otherwise fire on every
|
||||
# prefill-bearing step, lifting GPU utilization on long-prefill
|
||||
|
||||
Reference in New Issue
Block a user