[DSv4] Refactor DeepseekV4Attention

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-06-01 23:27:28 +00:00
parent 5ac2b4bdb0
commit 1ff2b11e17
6 changed files with 166 additions and 548 deletions
+2 -158
View File
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
@@ -46,10 +45,8 @@ from vllm.model_executor.models.utils import (
maybe_prefix,
)
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLA,
DeepseekV4Attention,
)
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_tilelang
@@ -225,158 +222,6 @@ class DeepseekV4MoE(nn.Module):
return final_hidden_states.view(org_shape)
class DeepseekV4Attention(nn.Module):
    """DeepseekV4 attention module for a single layer.

    The layer index is parsed from ``prefix``. The constructor builds the
    projection layers, norms, attention-sink parameter, rotary embedding and
    (only when the compress ratio is 4) the sparse indexer, then passes all of
    them to a ``DeepseekV4MLA`` instance which ``forward`` delegates to.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        topk_indices_buffer: torch.Tensor | None = None,
        aux_stream_list: list[torch.cuda.Stream] | None = None,
    ):
        """Create all sub-layers for the layer named by ``prefix``.

        Args:
            vllm_config: Source of the HF model config, quantization config
                and cache config.
            prefix: Module name prefix; also used to derive the layer index
                and the prefixes of the sub-layers.
            topk_indices_buffer: Forwarded to the indexer (if any) and to
                ``DeepseekV4MLA``.
            aux_stream_list: Forwarded to ``DeepseekV4MLA``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        layer_idx = extract_layer_index(prefix)
        self.layer_id = layer_idx

        # Head / rank geometry; head and group counts are divided across
        # tensor-parallel ranks.
        self.hidden_size = hf_config.hidden_size
        self.n_heads = hf_config.num_attention_heads
        world_size = get_tensor_model_parallel_world_size()
        assert self.n_heads % world_size == 0
        self.n_local_heads = self.n_heads // world_size
        self.q_lora_rank = hf_config.q_lora_rank
        self.o_lora_rank = hf_config.o_lora_rank
        self.head_dim = hf_config.head_dim
        self.rope_head_dim = hf_config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.n_groups = hf_config.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = hf_config.sliding_window

        # NOTE(zyongye): the compress ratio can't be 0, so it is clamped to
        # at least 1. MTP layers are not included in the compress ratio list,
        # so a layer index past num_hidden_layers falls back to 1.
        has_listed_ratio = layer_idx < hf_config.num_hidden_layers
        self.compress_ratio = (
            max(1, hf_config.compress_ratios[layer_idx]) if has_listed_ratio else 1
        )

        self.eps = hf_config.rms_norm_eps
        self.max_position_embeddings = hf_config.max_position_embeddings

        # Attention sink, padded to a minimum of 64 heads for FlashMLA and
        # initialized to -inf (no sink effect). Weight loading fills the
        # first n_local_heads slots.
        sink_init = torch.full(
            (max(self.n_local_heads, 64),),
            -float("inf"),
            dtype=torch.float32,
        )
        self.attn_sink = nn.Parameter(sink_init, requires_grad=False)

        # Keyword arguments shared by every projection below.
        base_linear_kwargs = {"bias": False, "quant_config": quant_config}
        plain_output_kwargs = {**base_linear_kwargs, "return_bias": False}

        self.fused_wqa_wkv = MergedColumnParallelLinear(
            self.hidden_size,
            [self.q_lora_rank, self.head_dim],
            prefix=f"{prefix}.fused_wqa_wkv",
            disable_tp=True,  # fused ReplicatedLinear
            **base_linear_kwargs,
        )
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            prefix=f"{prefix}.wq_b",
            **plain_output_kwargs,
        )
        self.kv_norm = RMSNorm(self.head_dim, self.eps)

        grouped_in_features = self.n_heads * self.head_dim // self.n_groups
        grouped_lora_features = self.n_groups * self.o_lora_rank
        self.wo_a = ColumnParallelLinear(
            grouped_in_features,
            grouped_lora_features,
            prefix=f"{prefix}.wo_a",
            **plain_output_kwargs,
        )
        # Tag wo_a as a batched matmul with one batch entry per local group.
        self.wo_a.is_bmm = True
        self.wo_a.bmm_batch_size = self.n_local_groups
        self.wo_b = RowParallelLinear(
            grouped_lora_features,
            self.hidden_size,
            prefix=f"{prefix}.wo_b",
            **plain_output_kwargs,
        )

        self.softmax_scale = self.head_dim**-0.5
        self.scale_fmt = hf_config.quantization_config["scale_fmt"]
        self.rope_parameters = hf_config.rope_scaling

        # The rotary embedding must exist before DeepseekV4MLA is built,
        # because it is passed into it.
        self.rotary_emb = build_deepseek_v4_rope(
            hf_config,
            head_dim=self.head_dim,
            rope_head_dim=self.rope_head_dim,
            max_position_embeddings=self.max_position_embeddings,
            compress_ratio=self.compress_ratio,
        )

        # Only C4A uses sparse attention and hence has an indexer.
        indexer = None
        if self.compress_ratio == 4:
            indexer = DeepseekV4Indexer(
                vllm_config,
                config=hf_config,
                hidden_size=self.hidden_size,
                q_lora_rank=self.q_lora_rank,
                quant_config=quant_config,
                cache_config=vllm_config.cache_config,
                topk_indices_buffer=topk_indices_buffer,
                compress_ratio=self.compress_ratio,
                prefix=f"{prefix}.indexer",
            )
        self.indexer = indexer

        self.mla_attn = DeepseekV4MLA(
            hidden_size=self.hidden_size,
            num_heads=self.n_local_heads,
            head_dim=self.head_dim,
            scale=self.softmax_scale,
            qk_nope_head_dim=self.nope_head_dim,
            qk_rope_head_dim=self.rope_head_dim,
            v_head_dim=self.head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.head_dim,
            o_lora_rank=self.o_lora_rank,
            vllm_config=vllm_config,
            fused_wqa_wkv=self.fused_wqa_wkv,
            q_norm=self.q_norm,
            wq_b=self.wq_b,
            kv_norm=self.kv_norm,
            wo_a=self.wo_a,
            wo_b=self.wo_b,
            attn_sink=self.attn_sink,
            rotary_emb=self.rotary_emb,
            indexer=self.indexer,
            indexer_rotary_emb=self.rotary_emb,
            topk_indices_buffer=topk_indices_buffer,
            aux_stream_list=aux_stream_list,
            window_size=self.window_size,
            compress_ratio=self.compress_ratio,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ):
        """Delegate to the wrapped ``DeepseekV4MLA`` and return its result."""
        attn_output = self.mla_attn(positions, hidden_states, llama_4_scaling)
        return attn_output
class DeepseekV4DecoderLayer(nn.Module):
def __init__(
self,
@@ -601,7 +446,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MLA.attn_gemm_parallel_execute
# DeepseekV4Attention.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
# Disable them on ROCm because of hang issues.
@@ -897,7 +742,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
".attn.compressor.": ".attn.mla_attn.compressor.",
".shared_experts.w2": ".shared_experts.down_proj",
},
)
+4 -4
View File
@@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.models.deepseek_v4.attention import (
DeepseekV4MLAAttention,
DeepseekV4Attention,
)
@@ -599,7 +599,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
@classmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -677,7 +677,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
@classmethod
def _forward_decode(
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
kv_cache: torch.Tensor | None,
swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata,
@@ -740,7 +740,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
@classmethod
def _forward_prefill(
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None,
+151 -214
View File
@@ -15,7 +15,10 @@ from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.models.deepseek_v4.common.ops import (
@@ -42,12 +45,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.compressor import DeepseekCompressor
from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import (
@@ -88,78 +87,89 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]":
return DeepseekV4FlashMLASparseImpl
class DeepseekV4MLA(nn.Module):
class DeepseekV4Attention(nn.Module, AttentionLayerBase):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
o_lora_rank: int | None,
vllm_config: VllmConfig,
fused_wqa_wkv: torch.nn.Module,
q_norm: torch.nn.Module,
wq_b: torch.nn.Module,
kv_norm: torch.nn.Module,
wo_a: torch.nn.Module,
wo_b: torch.nn.Module,
attn_sink: torch.nn.Module,
rotary_emb: torch.nn.Module,
indexer: torch.nn.Module | None,
indexer_rotary_emb: torch.nn.Module,
topk_indices_buffer: torch.Tensor | None,
aux_stream_list: list[torch.cuda.Stream] | None,
window_size: int,
compress_ratio: int | None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.window_size = window_size
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
self.prefix = prefix
# Extract config from vllm_config
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
tp_size = get_tensor_model_parallel_world_size()
layer_id = extract_layer_index(prefix)
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
self.eps = config.rms_norm_eps
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = head_dim - self.rope_head_dim
self.n_local_groups = config.o_groups // tp_size
self.prefix = prefix # Alias for compatibility with compressor
self.hidden_size = config.hidden_size
self.n_heads = config.num_attention_heads
assert self.n_heads % tp_size == 0
self.n_local_heads = self.n_heads // tp_size
self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
self.head_dim = config.head_dim
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.n_groups = config.o_groups
self.n_local_groups = self.n_groups // tp_size
self.window_size = config.sliding_window
# NOTE(zyongye) Compress ratio can't be 0
# we do this because the MTP layer is not included
# in the compress ratio list
if layer_id < config.num_hidden_layers:
self.compress_ratio = max(1, config.compress_ratios[layer_id])
else:
self.compress_ratio = 1
self.eps = config.rms_norm_eps
self.scale = self.head_dim**-0.5
# Store projection modules
self.fused_wqa_wkv = fused_wqa_wkv
self.q_norm = q_norm
self.wq_b = wq_b
self.kv_norm = kv_norm
self.wo_a = wo_a
self._wo_a_act_quant = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
use_ue8m0=True,
# Padded to min 64 heads for FlashMLA, initialized to -inf
# (no sink effect). Weight loading fills the first n_local_heads slots.
padded_heads = max(self.n_local_heads, 64)
self.attn_sink = nn.Parameter(
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
self.fused_wqa_wkv = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_wqa_wkv",
disable_tp=True, # fused ReplicatedLinear
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wq_b",
)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
self.wo_a = ColumnParallelLinear(
self.n_heads * self.head_dim // self.n_groups,
self.n_groups * self.o_lora_rank,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_a",
)
self.wo_a.is_bmm = True
self.wo_a.bmm_batch_size = self.n_local_groups
self.wo_b = RowParallelLinear(
self.n_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_b",
)
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
# INT32) so fp8_einsum can handle layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = wo_b
# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
@@ -169,22 +179,38 @@ class DeepseekV4MLA(nn.Module):
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
self._tma_aligned_scales = cap.major >= 10
self.rotary_emb = rotary_emb
self.indexer_rotary_emb = indexer_rotary_emb
# Initialize rotary embedding before the indexer/compressor consume it.
self.rotary_emb = build_deepseek_v4_rope(
config,
head_dim=self.head_dim,
rope_head_dim=self.rope_head_dim,
max_position_embeddings=config.max_position_embeddings,
compress_ratio=self.compress_ratio,
)
self.indexer_rotary_emb = self.rotary_emb
self.topk_indices_buffer = topk_indices_buffer
self.indexer = indexer
# Per-head RMS normalization for Q (no learnable weights)
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
# TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
head_bytes = (
self.nope_head_dim # 448 fp8 NoPE
+ self.rope_head_dim * 2 # 64 bf16 RoPE
+ self.nope_head_dim // 64 # 7B scale factors
+ 1 # 1B pad
)
self.indexer = None
if self.compress_ratio == 4:
# Only C4A uses sparse attention and hence has indexer.
# aux_stream_list[2] is free here (outer GEMMs joined) for the inner
# overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on
# ROCm, where aux_stream_list is None.
indexer_aux_stream = (
aux_stream_list[2] if aux_stream_list is not None else None
)
self.indexer = DeepseekV4Indexer(
vllm_config,
config=config,
hidden_size=self.hidden_size,
q_lora_rank=self.q_lora_rank,
quant_config=quant_config,
cache_config=cache_config,
topk_indices_buffer=topk_indices_buffer,
compress_ratio=self.compress_ratio,
prefix=f"{prefix}.indexer",
aux_stream=indexer_aux_stream,
)
# Will be None on ROCm for now.
self.aux_stream_list = aux_stream_list
@@ -202,30 +228,48 @@ class DeepseekV4MLA(nn.Module):
cache_config=cache_config,
)
self.mla_attn = DeepseekV4MLAAttention(
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
compress_ratio=self.compress_ratio,
window_size=self.window_size,
head_bytes=head_bytes,
swa_cache_layer=self.swa_cache_layer,
attn_sink=attn_sink, # already padded with -inf
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
indexer=self.indexer,
topk_indices_buffer=self.topk_indices_buffer,
)
# Mirror the inner layer's padded head count (single source of truth).
self.padded_heads = self.mla_attn.padded_heads
# ---- Attention layer setup (formerly DeepseekV4MLAAttention) ----
self.impl_cls = _select_v4_sparse_impl()
self.backend_cls = self.impl_cls.backend_cls
# Padded Q head count is dictated by the selected impl.
self.padded_heads = self.impl_cls.get_padded_num_q_heads(self.n_local_heads)
# Create the compressor for layers with compress_ratio > 1; after
# creating the DeepseekV4MLAAttention layer to get its cache.
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
self.max_model_len = vllm_config.model_config.max_model_len
# DeepseekV4 only supports fp8 kv-cache format for now.
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
assert kv_cache_dtype.startswith("fp8"), (
f"DeepseekV4 only supports fp8 kv-cache format for now, "
f"got {kv_cache_dtype}"
)
assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
"Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
)
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if (
issubclass(self.get_attn_backend(), FlashMLASparseBackend)
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"
):
assert cache_config is not None
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
self.kv_cache_dtype = kv_cache_dtype
# Register with compilation context for metadata lookup
compilation_config = vllm_config.compilation_config
if prefix and prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
if prefix:
compilation_config.static_forward_context[prefix] = self
self.kv_cache = torch.tensor([])
# Create the compressor for layers with compress_ratio > 1; after the
# attention-layer setup above so its KV cache prefix is available.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
@@ -235,7 +279,7 @@ class DeepseekV4MLA(nn.Module):
head_dim=self.head_dim,
rotate=True,
prefix=f"{prefix}.compressor",
k_cache_prefix=self.mla_attn.prefix,
k_cache_prefix=self.prefix,
)
def forward(
@@ -449,7 +493,7 @@ class DeepseekV4MLA(nn.Module):
# MLA attention writes into the pre-allocated `out` buffer
# ([num_tokens, padded_heads, head_dim]).
self.mla_attn(q, kv, positions, output=out)
self.impl_cls.forward_mqa(self, q, kv, positions, out)
def _fused_qnorm_rope_kv_insert(
self,
@@ -498,102 +542,6 @@ class DeepseekV4MLA(nn.Module):
swa_metadata.block_size,
)
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
def __init__(
    self,
    num_heads: int,
    head_dim: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    compress_ratio: int,
    window_size: int,
    head_bytes: int,
    swa_cache_layer: DeepseekV4SWACache,
    attn_sink: torch.Tensor,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    # Sparse MLA Args
    indexer: object | None = None,
    topk_indices_buffer: torch.Tensor | None = None,
    aux_stream: torch.cuda.Stream | None = None,
    **extra_impl_args,
) -> None:
    """Set up the attention layer state used by the sparse MLA impl.

    Stores the geometry arguments, selects the impl/backend classes,
    validates and normalizes the KV-cache dtype (possibly rewriting
    ``cache_config.cache_dtype`` to ``"fp8_ds_mla"``), and registers the layer
    under ``prefix`` in the compilation config's static forward context.

    Raises:
        ValueError: If ``prefix`` is non-empty and already registered.
    """
    super().__init__()

    sparse_impl = _select_v4_sparse_impl()
    self.impl_cls = sparse_impl
    self.backend_cls = sparse_impl.backend_cls

    # Plain bookkeeping copied from the arguments.
    self.num_heads = num_heads
    self.num_kv_heads = 1
    self.head_dim = head_dim
    self.scale = scale
    self.window_size = window_size
    self.head_bytes = head_bytes
    self.compress_ratio = compress_ratio
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.nope_head_dim = qk_nope_head_dim
    self.rope_head_dim = qk_rope_head_dim
    self.indexer = indexer
    self.topk_indices_buffer = topk_indices_buffer
    self.prefix = prefix  # Alias for compatibility with compressor
    self.aux_stream = aux_stream
    self.ln_events = [torch.cuda.Event() for _ in range(2)]

    # The selected impl dictates the padded Q head count.
    self.padded_heads = sparse_impl.get_padded_num_q_heads(num_heads)

    # Both the attention sink and the SWA cache layer are mandatory.
    assert attn_sink is not None
    self.attn_sink: torch.Tensor = attn_sink
    assert swa_cache_layer is not None
    self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer

    # Limits read from the current vllm config.
    vllm_config = get_current_vllm_config()
    scheduler_config = vllm_config.scheduler_config
    self.max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    self.max_model_len = vllm_config.model_config.max_model_len

    # DeepseekV4 only supports fp8 kv-cache format for now.
    kv_cache_dtype = "fp8" if cache_config is None else cache_config.cache_dtype
    assert kv_cache_dtype.startswith("fp8"), (
        f"DeepseekV4 only supports fp8 kv-cache format for now, "
        f"got {kv_cache_dtype}"
    )
    assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
        "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
    )

    # The FlashMLA Sparse Attention fp8 backend uses the "fp8_ds_mla"
    # kv-cache format, so any other fp8 format is converted automatically.
    needs_ds_mla_format = (
        issubclass(self.get_attn_backend(), FlashMLASparseBackend)
        and kv_cache_dtype.startswith("fp8")
        and kv_cache_dtype != "fp8_ds_mla"
    )
    if needs_ds_mla_format:
        assert cache_config is not None
        kv_cache_dtype = "fp8_ds_mla"
        cache_config.cache_dtype = kv_cache_dtype
        logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
    self.kv_cache_dtype = kv_cache_dtype

    # Register with the compilation context for metadata lookup; an empty
    # prefix is never registered.
    forward_context = vllm_config.compilation_config.static_forward_context
    if prefix:
        if prefix in forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_context[prefix] = self

    self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class stored in ``self.backend_cls``."""
    return self.backend_cls
@@ -613,15 +561,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
model_version="deepseek_v4",
)
def forward(
    self,
    q: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Run the selected sparse MLA implementation for this layer.

    Hands ``q``, ``kv``, ``positions`` and ``output`` to
    ``impl_cls.forward_mqa`` together with this layer and returns nothing.
    NOTE(review): ``output`` is presumably filled in place by the impl;
    confirm against the impl's ``forward_mqa``.
    """
    # forward_mqa is called on the impl class itself, so the layer is passed
    # explicitly as its first argument.
    self.impl_cls.forward_mqa(self, q, kv, positions, output)
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(
@@ -778,8 +717,6 @@ class DeepseekV4Indexer(nn.Module):
positions: torch.Tensor,
rotary_emb: nn.Module,
) -> torch.Tensor:
compressor = self.compressor
def wq_b_and_q_quant():
# ReplicatedLinear returns (output, bias); bias is None.
q, _ = self.wq_b(qr)
@@ -798,7 +735,7 @@ class DeepseekV4Indexer(nn.Module):
# join orders that write before indexer_op (skip_k_cache_insert=True).
(q_quant, weights), k = maybe_execute_in_parallel(
wq_b_and_q_quant,
lambda: compressor(compressed_kv_score, positions, rotary_emb),
lambda: self.compressor(compressed_kv_score, positions, rotary_emb),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
+6 -6
View File
@@ -29,7 +29,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.models.deepseek_v4.attention import (
DeepseekV4MLAAttention,
DeepseekV4Attention,
)
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
"""Abstract parent for DeepseekV4 sparse MLA impls.
V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``)
V4 sparse MLA is driven by the layer (``DeepseekV4Attention.forward``)
rather than the v1 framework, so ``forward_mqa`` is overridden with a
classmethod that takes the layer as its first argument. This Liskov-broken
override is intentional: the grandparent's instance-method ``forward_mqa``
@@ -55,7 +55,7 @@ class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMeta
@abstractmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -129,7 +129,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
@classmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -210,7 +210,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
@classmethod
def _forward_decode(
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1
swa_metadata: "DeepseekSparseSWAMetadata",
@@ -304,7 +304,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
@classmethod
def _forward_prefill(
cls,
layer: "DeepseekV4MLAAttention",
layer: "DeepseekV4Attention",
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1
+2 -165
View File
@@ -29,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
@@ -52,10 +51,8 @@ from vllm.model_executor.models.utils import (
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLA,
DeepseekV4Attention,
)
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.sequence import IntermediateTensors
@@ -608,165 +605,6 @@ class DeepseekV4MoE(nn.Module):
self.experts.finalize_weights()
class DeepseekV4Attention(nn.Module):
    """DeepseekV4 attention module for a single layer.

    The layer index is parsed from ``prefix``. The constructor builds the
    projection layers, norms, attention-sink parameter, rotary embedding and
    (only when the compress ratio is 4) the sparse indexer, then passes all of
    them to a ``DeepseekV4MLA`` instance which ``forward`` delegates to.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        topk_indices_buffer: torch.Tensor | None = None,
        aux_stream_list: list[torch.cuda.Stream] | None = None,
    ):
        """Create all sub-layers for the layer named by ``prefix``.

        Args:
            vllm_config: Source of the HF model config, quantization config
                and cache config.
            prefix: Module name prefix; also used to derive the layer index
                and the prefixes of the sub-layers.
            topk_indices_buffer: Forwarded to the indexer (if any) and to
                ``DeepseekV4MLA``.
            aux_stream_list: Forwarded to ``DeepseekV4MLA``; when given, its
                entry at index 2 is also handed to the indexer.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        layer_idx = extract_layer_index(prefix)
        self.layer_id = layer_idx

        # Head / rank geometry; head and group counts are divided across
        # tensor-parallel ranks.
        self.hidden_size = hf_config.hidden_size
        self.n_heads = hf_config.num_attention_heads
        world_size = get_tensor_model_parallel_world_size()
        assert self.n_heads % world_size == 0
        self.n_local_heads = self.n_heads // world_size
        self.q_lora_rank = hf_config.q_lora_rank
        self.o_lora_rank = hf_config.o_lora_rank
        self.head_dim = hf_config.head_dim
        self.rope_head_dim = hf_config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.n_groups = hf_config.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = hf_config.sliding_window

        # NOTE(zyongye): the compress ratio can't be 0, so it is clamped to
        # at least 1. MTP layers are not included in the compress ratio list,
        # so a layer index past num_hidden_layers falls back to 1.
        has_listed_ratio = layer_idx < hf_config.num_hidden_layers
        self.compress_ratio = (
            max(1, hf_config.compress_ratios[layer_idx]) if has_listed_ratio else 1
        )

        self.eps = hf_config.rms_norm_eps
        self.max_position_embeddings = hf_config.max_position_embeddings

        # Attention sink, padded to a minimum of 64 heads for FlashMLA and
        # initialized to -inf (no sink effect). Weight loading fills the
        # first n_local_heads slots.
        sink_init = torch.full(
            (max(self.n_local_heads, 64),),
            -float("inf"),
            dtype=torch.float32,
        )
        self.attn_sink = nn.Parameter(sink_init, requires_grad=False)

        # Keyword arguments shared by every projection below.
        base_linear_kwargs = {"bias": False, "quant_config": quant_config}
        plain_output_kwargs = {**base_linear_kwargs, "return_bias": False}

        self.fused_wqa_wkv = MergedColumnParallelLinear(
            self.hidden_size,
            [self.q_lora_rank, self.head_dim],
            prefix=f"{prefix}.fused_wqa_wkv",
            disable_tp=True,  # fused ReplicatedLinear
            **base_linear_kwargs,
        )
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            prefix=f"{prefix}.wq_b",
            **plain_output_kwargs,
        )
        self.kv_norm = RMSNorm(self.head_dim, self.eps)

        grouped_in_features = self.n_heads * self.head_dim // self.n_groups
        grouped_lora_features = self.n_groups * self.o_lora_rank
        self.wo_a = ColumnParallelLinear(
            grouped_in_features,
            grouped_lora_features,
            prefix=f"{prefix}.wo_a",
            **plain_output_kwargs,
        )
        # Tag wo_a as a batched matmul with one batch entry per local group.
        self.wo_a.is_bmm = True
        self.wo_a.bmm_batch_size = self.n_local_groups
        self.wo_b = RowParallelLinear(
            grouped_lora_features,
            self.hidden_size,
            prefix=f"{prefix}.wo_b",
            **plain_output_kwargs,
        )

        self.softmax_scale = self.head_dim**-0.5
        self.scale_fmt = hf_config.quantization_config["scale_fmt"]
        self.rope_parameters = hf_config.rope_scaling

        # The rotary embedding must exist before DeepseekV4MLA is built,
        # because it is passed into it.
        self.rotary_emb = build_deepseek_v4_rope(
            hf_config,
            head_dim=self.head_dim,
            rope_head_dim=self.rope_head_dim,
            max_position_embeddings=self.max_position_embeddings,
            compress_ratio=self.compress_ratio,
        )

        # Only C4A uses sparse attention and hence has an indexer.
        indexer = None
        if self.compress_ratio == 4:
            # aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is
            # free here (outer GEMMs joined) for the inner overlap of
            # wq_b+fused_indexer_q_rope_quant vs compressor.
            if aux_stream_list is None:
                indexer_aux_stream = None
            else:
                indexer_aux_stream = aux_stream_list[2]
            indexer = DeepseekV4Indexer(
                vllm_config,
                config=hf_config,
                hidden_size=self.hidden_size,
                q_lora_rank=self.q_lora_rank,
                quant_config=quant_config,
                cache_config=vllm_config.cache_config,
                topk_indices_buffer=topk_indices_buffer,
                compress_ratio=self.compress_ratio,
                prefix=f"{prefix}.indexer",
                aux_stream=indexer_aux_stream,
            )
        self.indexer = indexer

        self.mla_attn = DeepseekV4MLA(
            hidden_size=self.hidden_size,
            num_heads=self.n_local_heads,
            head_dim=self.head_dim,
            scale=self.softmax_scale,
            qk_nope_head_dim=self.nope_head_dim,
            qk_rope_head_dim=self.rope_head_dim,
            v_head_dim=self.head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.head_dim,
            o_lora_rank=self.o_lora_rank,
            vllm_config=vllm_config,
            fused_wqa_wkv=self.fused_wqa_wkv,
            q_norm=self.q_norm,
            wq_b=self.wq_b,
            kv_norm=self.kv_norm,
            wo_a=self.wo_a,
            wo_b=self.wo_b,
            attn_sink=self.attn_sink,
            rotary_emb=self.rotary_emb,
            indexer=self.indexer,
            indexer_rotary_emb=self.rotary_emb,
            topk_indices_buffer=topk_indices_buffer,
            aux_stream_list=aux_stream_list,
            window_size=self.window_size,
            compress_ratio=self.compress_ratio,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ):
        """Delegate to the wrapped ``DeepseekV4MLA`` and return its result."""
        attn_output = self.mla_attn(positions, hidden_states, llama_4_scaling)
        return attn_output
class DeepseekV4DecoderLayer(nn.Module):
def __init__(
self,
@@ -938,7 +776,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MLA.attn_gemm_parallel_execute
# DeepseekV4Attention.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
@@ -1236,7 +1074,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
".attn.compressor.": ".attn.mla_attn.compressor.",
".shared_experts.w2": ".shared_experts.down_proj",
},
)
@@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
fp8_use_mixed_batch = (
self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
)
# DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
# DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not
# consume fp8_extra_metadata. Skipping the build here avoids a
# forced D2H sync on seq_lens that would otherwise fire on every
# prefill-bearing step, lifting GPU utilization on long-prefill