From 1ff2b11e1728e57d7a232a8f12544cb564da230e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 1 Jun 2026 23:27:28 +0000 Subject: [PATCH] [DSv4] Refactor DeepseekV4Attention Signed-off-by: Woosuk Kwon --- vllm/models/deepseek_v4/amd/model.py | 160 +------- vllm/models/deepseek_v4/amd/rocm.py | 8 +- vllm/models/deepseek_v4/attention.py | 365 ++++++++---------- vllm/models/deepseek_v4/nvidia/flashmla.py | 12 +- vllm/models/deepseek_v4/nvidia/model.py | 167 +------- .../attention/backends/mla/flashmla_sparse.py | 2 +- 6 files changed, 166 insertions(+), 548 deletions(-) diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index fb724fbe2f1..744c502f8dd 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, ) @@ -46,10 +45,8 @@ from vllm.model_executor.models.utils import ( maybe_prefix, ) from vllm.models.deepseek_v4.attention import ( - DeepseekV4Indexer, - DeepseekV4MLA, + DeepseekV4Attention, ) -from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import has_tilelang @@ -225,158 +222,6 @@ class DeepseekV4MoE(nn.Module): return final_hidden_states.view(org_shape) -class DeepseekV4Attention(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream_list: list[torch.cuda.Stream] | None = None, - ): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - layer_id = extract_layer_index(prefix) - - self.layer_id = layer_id - self.hidden_size = config.hidden_size - self.n_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - assert self.n_heads % tp_size == 0 - - self.n_local_heads = self.n_heads // tp_size - self.q_lora_rank = config.q_lora_rank - self.o_lora_rank = config.o_lora_rank - self.head_dim = config.head_dim - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = self.head_dim - self.rope_head_dim - self.n_groups = config.o_groups - self.n_local_groups = self.n_groups // tp_size - self.window_size = config.sliding_window - # NOTE(zyongye) Compress ratio can't be 0 - # we do this for because MTP layer is not included - # in the compress ratio list - if layer_id < config.num_hidden_layers: - self.compress_ratio = max(1, config.compress_ratios[layer_id]) - else: - self.compress_ratio = 1 - self.eps = config.rms_norm_eps - self.max_position_embeddings = config.max_position_embeddings - - # Padded to min 64 heads for FlashMLA, initialized to -inf - # (no sink effect). Weight loading fills the first n_local_heads slots. - padded_heads = max(self.n_local_heads, 64) - self.attn_sink = nn.Parameter( - torch.full((padded_heads,), -float("inf"), dtype=torch.float32), - requires_grad=False, - ) - - self.fused_wqa_wkv = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_wqa_wkv", - disable_tp=True, # fused ReplicatedLinear - ) - self.q_norm = RMSNorm(self.q_lora_rank, self.eps) - self.wq_b = ColumnParallelLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wq_b", - ) - - self.kv_norm = RMSNorm(self.head_dim, self.eps) - self.wo_a = ColumnParallelLinear( - self.n_heads * self.head_dim // self.n_groups, - self.n_groups * self.o_lora_rank, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_a", - ) - self.wo_a.is_bmm = True - self.wo_a.bmm_batch_size = self.n_local_groups - self.wo_b = RowParallelLinear( - self.n_groups * self.o_lora_rank, - self.hidden_size, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_b", - ) - self.softmax_scale = self.head_dim**-0.5 - self.scale_fmt = config.quantization_config["scale_fmt"] - - self.rope_parameters = config.rope_scaling - - # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) - self.rotary_emb = build_deepseek_v4_rope( - config, - head_dim=self.head_dim, - rope_head_dim=self.rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - compress_ratio=self.compress_ratio, - ) - - self.indexer = None - if self.compress_ratio == 4: - # Only C4A uses sparse attention and hence has indexer. - self.indexer = DeepseekV4Indexer( - vllm_config, - config=config, - hidden_size=self.hidden_size, - q_lora_rank=self.q_lora_rank, - quant_config=quant_config, - cache_config=vllm_config.cache_config, - topk_indices_buffer=topk_indices_buffer, - compress_ratio=self.compress_ratio, - prefix=f"{prefix}.indexer", - ) - - self.mla_attn = DeepseekV4MLA( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - vllm_config=vllm_config, - fused_wqa_wkv=self.fused_wqa_wkv, - q_norm=self.q_norm, - wq_b=self.wq_b, - kv_norm=self.kv_norm, - wo_a=self.wo_a, - wo_b=self.wo_b, - attn_sink=self.attn_sink, - rotary_emb=self.rotary_emb, - indexer=self.indexer, - indexer_rotary_emb=self.rotary_emb, - topk_indices_buffer=topk_indices_buffer, - aux_stream_list=aux_stream_list, - window_size=self.window_size, - compress_ratio=self.compress_ratio, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - prefix=prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - llama_4_scaling: torch.Tensor | None, - ): - return self.mla_attn(positions, hidden_states, llama_4_scaling) - - class DeepseekV4DecoderLayer(nn.Module): def __init__( self, @@ -601,7 +446,7 @@ class DeepseekV4Model(nn.Module): self.rms_norm_eps = config.rms_norm_eps # Three aux streams: one per non-default input GEMM in - # DeepseekV4MLA.attn_gemm_parallel_execute + # DeepseekV4Attention.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. # Disable them on ROCm because of hang issues. @@ -897,7 +742,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", }, orig_to_new_substr={ - ".attn.compressor.": ".attn.mla_attn.compressor.", ".shared_experts.w2": ".shared_experts.down_proj", }, ) diff --git a/vllm/models/deepseek_v4/amd/rocm.py b/vllm/models/deepseek_v4/amd/rocm.py index 2af93fba31e..7470e5bab07 100644 --- a/vllm/models/deepseek_v4/amd/rocm.py +++ b/vllm/models/deepseek_v4/amd/rocm.py @@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: from vllm.models.deepseek_v4.attention import ( - DeepseekV4MLAAttention, + DeepseekV4Attention, ) @@ -599,7 +599,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): @classmethod def forward_mqa( # type: ignore[override] cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, kv: torch.Tensor, positions: torch.Tensor, @@ -677,7 +677,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): @classmethod def _forward_decode( cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, kv_cache: torch.Tensor | None, swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata, @@ -740,7 +740,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): @classmethod def _forward_prefill( cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, positions: torch.Tensor, compressed_k_cache: torch.Tensor | None, diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 55cb3d94ba6..9d6a5e71501 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -15,7 +15,10 @@ from transformers import DeepseekV2Config, DeepseekV3Config import vllm.envs as envs from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, ReplicatedLinear, + RowParallelLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer from vllm.models.deepseek_v4.common.ops import ( @@ -42,12 +45,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.input_quant_fp8 import ( - QuantFP8, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, -) +from vllm.model_executor.models.utils import extract_layer_index +from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope from vllm.models.deepseek_v4.compressor import DeepseekCompressor from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( @@ -88,78 +87,89 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": return DeepseekV4FlashMLASparseImpl -class DeepseekV4MLA(nn.Module): +class DeepseekV4Attention(nn.Module, AttentionLayerBase): def __init__( self, - hidden_size: int, - num_heads: int, - head_dim: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: int | None, - kv_lora_rank: int, - o_lora_rank: int | None, vllm_config: VllmConfig, - fused_wqa_wkv: torch.nn.Module, - q_norm: torch.nn.Module, - wq_b: torch.nn.Module, - kv_norm: torch.nn.Module, - wo_a: torch.nn.Module, - wo_b: torch.nn.Module, - attn_sink: torch.nn.Module, - rotary_emb: torch.nn.Module, - indexer: torch.nn.Module | None, - indexer_rotary_emb: torch.nn.Module, - topk_indices_buffer: torch.Tensor | None, - aux_stream_list: list[torch.cuda.Stream] | None, - window_size: int, - compress_ratio: int | None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", + prefix: str, + topk_indices_buffer: torch.Tensor | None = None, + aux_stream_list: list[torch.cuda.Stream] | None = None, ) -> None: super().__init__() - self.hidden_size = hidden_size - self.n_local_heads = num_heads - self.head_dim = head_dim - self.scale = scale - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.window_size = window_size - self.compress_ratio = compress_ratio if compress_ratio is not None else 1 - self.prefix = prefix - - # Extract config from vllm_config config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config tp_size = get_tensor_model_parallel_world_size() + layer_id = extract_layer_index(prefix) - # DeepseekV4-specific attributes (num_heads is already TP-adjusted) - self.eps = config.rms_norm_eps - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = head_dim - self.rope_head_dim - self.n_local_groups = config.o_groups // tp_size + self.prefix = prefix # Alias for compatibility with compressor + self.hidden_size = config.hidden_size + self.n_heads = config.num_attention_heads + assert self.n_heads % tp_size == 0 + self.n_local_heads = self.n_heads // tp_size + self.q_lora_rank = config.q_lora_rank self.o_lora_rank = config.o_lora_rank + self.head_dim = config.head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = self.head_dim - self.rope_head_dim + self.n_groups = config.o_groups + self.n_local_groups = self.n_groups // tp_size + self.window_size = config.sliding_window + # NOTE(zyongye) Compress ratio can't be 0 + # we do this for because MTP layer is not included + # in the compress ratio list + if layer_id < config.num_hidden_layers: + self.compress_ratio = max(1, config.compress_ratios[layer_id]) + else: + self.compress_ratio = 1 + self.eps = config.rms_norm_eps + self.scale = self.head_dim**-0.5 - # Store projection modules - self.fused_wqa_wkv = fused_wqa_wkv - self.q_norm = q_norm - self.wq_b = wq_b - - self.kv_norm = kv_norm - self.wo_a = wo_a - - self._wo_a_act_quant = QuantFP8( - static=False, - group_shape=GroupShape(1, 128), - use_ue8m0=True, + # Padded to min 64 heads for FlashMLA, initialized to -inf + # (no sink effect). Weight loading fills the first n_local_heads slots. + padded_heads = max(self.n_local_heads, 64) + self.attn_sink = nn.Parameter( + torch.full((padded_heads,), -float("inf"), dtype=torch.float32), + requires_grad=False, + ) + + self.fused_wqa_wkv = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_wqa_wkv", + disable_tp=True, # fused ReplicatedLinear + ) + self.q_norm = RMSNorm(self.q_lora_rank, self.eps) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wq_b", + ) + + self.kv_norm = RMSNorm(self.head_dim, self.eps) + self.wo_a = ColumnParallelLinear( + self.n_heads * self.head_dim // self.n_groups, + self.n_groups * self.o_lora_rank, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wo_a", + ) + self.wo_a.is_bmm = True + self.wo_a.bmm_batch_size = self.n_local_groups + self.wo_b = RowParallelLinear( + self.n_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wo_b", ) - # Bypass packed-for-deepgemm path — we need FP32 scales (not packed - # INT32) so fp8_einsum can handle layout transform internally. - self._wo_a_act_quant.use_deep_gemm_supported = False - self.wo_b = wo_b # Pick fp8_einsum recipe based on GPU arch: # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 @@ -169,22 +179,38 @@ class DeepseekV4MLA(nn.Module): self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) self._tma_aligned_scales = cap.major >= 10 - self.rotary_emb = rotary_emb - self.indexer_rotary_emb = indexer_rotary_emb + # Initialize rotary embedding before the indexer/compressor consume it. + self.rotary_emb = build_deepseek_v4_rope( + config, + head_dim=self.head_dim, + rope_head_dim=self.rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + compress_ratio=self.compress_ratio, + ) + self.indexer_rotary_emb = self.rotary_emb self.topk_indices_buffer = topk_indices_buffer - self.indexer = indexer - - # Per-head RMS normalization for Q (no learnable weights) - self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) - - # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic - head_bytes = ( - self.nope_head_dim # 448 fp8 NoPE - + self.rope_head_dim * 2 # 64 bf16 RoPE - + self.nope_head_dim // 64 # 7B scale factors - + 1 # 1B pad - ) + self.indexer = None + if self.compress_ratio == 4: + # Only C4A uses sparse attention and hence has indexer. + # aux_stream_list[2] is free here (outer GEMMs joined) for the inner + # overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on + # ROCm, where aux_stream_list is None. + indexer_aux_stream = ( + aux_stream_list[2] if aux_stream_list is not None else None + ) + self.indexer = DeepseekV4Indexer( + vllm_config, + config=config, + hidden_size=self.hidden_size, + q_lora_rank=self.q_lora_rank, + quant_config=quant_config, + cache_config=cache_config, + topk_indices_buffer=topk_indices_buffer, + compress_ratio=self.compress_ratio, + prefix=f"{prefix}.indexer", + aux_stream=indexer_aux_stream, + ) # Will be None on ROCm for now. self.aux_stream_list = aux_stream_list @@ -202,30 +228,48 @@ class DeepseekV4MLA(nn.Module): cache_config=cache_config, ) - self.mla_attn = DeepseekV4MLAAttention( - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - compress_ratio=self.compress_ratio, - window_size=self.window_size, - head_bytes=head_bytes, - swa_cache_layer=self.swa_cache_layer, - attn_sink=attn_sink, # already padded with -inf - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - indexer=self.indexer, - topk_indices_buffer=self.topk_indices_buffer, - ) - # Mirror the inner layer's padded head count (single source of truth). - self.padded_heads = self.mla_attn.padded_heads + # ---- Attention layer setup (formerly DeepseekV4MLAAttention) ---- + self.impl_cls = _select_v4_sparse_impl() + self.backend_cls = self.impl_cls.backend_cls + # Padded Q head count is dictated by the selected impl. + self.padded_heads = self.impl_cls.get_padded_num_q_heads(self.n_local_heads) - # Create the compressor for layers with compress_ratio > 1; after - # creating the DeepseekV4MLAAttention layer to get its cache. + self.max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + ) + self.max_model_len = vllm_config.model_config.max_model_len + # DeepseekV4 only supports fp8 kv-cache format for now. + kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" + assert kv_cache_dtype.startswith("fp8"), ( + f"DeepseekV4 only supports fp8 kv-cache format for now, " + f"got {kv_cache_dtype}" + ) + assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( + "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" + ) + # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format + # Automatically convert fp8 kv-cache format to "fp8_ds_mla" + if ( + issubclass(self.get_attn_backend(), FlashMLASparseBackend) + and kv_cache_dtype.startswith("fp8") + and kv_cache_dtype != "fp8_ds_mla" + ): + assert cache_config is not None + cache_config.cache_dtype = "fp8_ds_mla" + kv_cache_dtype = "fp8_ds_mla" + logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") + self.kv_cache_dtype = kv_cache_dtype + + # Register with compilation context for metadata lookup + compilation_config = vllm_config.compilation_config + if prefix and prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + if prefix: + compilation_config.static_forward_context[prefix] = self + self.kv_cache = torch.tensor([]) + + # Create the compressor for layers with compress_ratio > 1; after the + # attention-layer setup above so its KV cache prefix is available. self.compressor = None if self.compress_ratio > 1: self.compressor = DeepseekCompressor( @@ -235,7 +279,7 @@ class DeepseekV4MLA(nn.Module): head_dim=self.head_dim, rotate=True, prefix=f"{prefix}.compressor", - k_cache_prefix=self.mla_attn.prefix, + k_cache_prefix=self.prefix, ) def forward( @@ -449,7 +493,7 @@ class DeepseekV4MLA(nn.Module): # MLA attention writes into the pre-allocated `out` buffer # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) + self.impl_cls.forward_mqa(self, q, kv, positions, out) def _fused_qnorm_rope_kv_insert( self, @@ -498,102 +542,6 @@ class DeepseekV4MLA(nn.Module): swa_metadata.block_size, ) - -class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): - def __init__( - self, - num_heads: int, - head_dim: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - q_lora_rank: int | None, - kv_lora_rank: int, - compress_ratio: int, - window_size: int, - head_bytes: int, - swa_cache_layer: DeepseekV4SWACache, - attn_sink: torch.Tensor, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - # Sparse MLA Args - indexer: object | None = None, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream: torch.cuda.Stream | None = None, - **extra_impl_args, - ) -> None: - super().__init__() - self.impl_cls = _select_v4_sparse_impl() - self.backend_cls = self.impl_cls.backend_cls - self.num_heads = num_heads - self.num_kv_heads = 1 - self.head_dim = head_dim - self.scale = scale - self.window_size = window_size - self.head_bytes = head_bytes - self.compress_ratio = compress_ratio - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.nope_head_dim = qk_nope_head_dim - self.rope_head_dim = qk_rope_head_dim - self.indexer = indexer - self.topk_indices_buffer = topk_indices_buffer - - self.prefix = prefix # Alias for compatibility with compressor - - self.aux_stream = aux_stream - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] - - # Padded Q head count is dictated by the selected impl. - self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads) - - # Store attention sink - assert attn_sink is not None - self.attn_sink: torch.Tensor = attn_sink - # Store SWA cache - assert swa_cache_layer is not None - self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer - - # Get vllm config for cache setup - vllm_config = get_current_vllm_config() - self.max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens - ) - self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now. - kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" - - assert kv_cache_dtype.startswith("fp8"), ( - f"DeepseekV4 only supports fp8 kv-cache format for now, " - f"got {kv_cache_dtype}" - ) - assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( - "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" - ) - # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format - # Automatically convert fp8 kv-cache format to "fp8_ds_mla" - if ( - issubclass(self.get_attn_backend(), FlashMLASparseBackend) - and kv_cache_dtype.startswith("fp8") - and kv_cache_dtype != "fp8_ds_mla" - ): - assert cache_config is not None - cache_config.cache_dtype = "fp8_ds_mla" - kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") - - self.kv_cache_dtype = kv_cache_dtype - - # Register with compilation context for metadata lookup - compilation_config = vllm_config.compilation_config - if prefix and prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - if prefix: - compilation_config.static_forward_context[prefix] = self - - self.kv_cache = torch.tensor([]) - def get_attn_backend(self) -> type[AttentionBackend]: return self.backend_cls @@ -613,15 +561,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): model_version="deepseek_v4", ) - def forward( - self, - q: torch.Tensor, - kv: torch.Tensor, - positions: torch.Tensor, - output: torch.Tensor, - ) -> None: - self.impl_cls.forward_mqa(self, q, kv, positions, output) - class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): def __init__( @@ -778,8 +717,6 @@ class DeepseekV4Indexer(nn.Module): positions: torch.Tensor, rotary_emb: nn.Module, ) -> torch.Tensor: - compressor = self.compressor - def wq_b_and_q_quant(): # ReplicatedLinear returns (output, bias); bias is None. q, _ = self.wq_b(qr) @@ -798,7 +735,7 @@ class DeepseekV4Indexer(nn.Module): # join orders that write before indexer_op (skip_k_cache_insert=True). (q_quant, weights), k = maybe_execute_in_parallel( wq_b_and_q_quant, - lambda: compressor(compressed_kv_score, positions, rotary_emb), + lambda: self.compressor(compressed_kv_score, positions, rotary_emb), self.ln_events[0], self.ln_events[1], self.aux_stream, diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 5c8b08d4c12..0cf6712633c 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -29,7 +29,7 @@ from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: from vllm.models.deepseek_v4.attention import ( - DeepseekV4MLAAttention, + DeepseekV4Attention, ) from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata @@ -37,7 +37,7 @@ if TYPE_CHECKING: class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): """Abstract parent for DeepseekV4 sparse MLA impls. - V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``) + V4 sparse MLA is driven by the layer (``DeepseekV4Attention.forward``) rather than the v1 framework, so ``forward_mqa`` is overridden with a classmethod that takes the layer as its first argument. This Liskov-broken override is intentional: the grandparent's instance-method ``forward_mqa`` @@ -55,7 +55,7 @@ class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMeta @abstractmethod def forward_mqa( # type: ignore[override] cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, kv: torch.Tensor, positions: torch.Tensor, @@ -129,7 +129,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): @classmethod def forward_mqa( # type: ignore[override] cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, kv: torch.Tensor, positions: torch.Tensor, @@ -210,7 +210,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): @classmethod def _forward_decode( cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1 swa_metadata: "DeepseekSparseSWAMetadata", @@ -304,7 +304,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): @classmethod def _forward_prefill( cls, - layer: "DeepseekV4MLAAttention", + layer: "DeepseekV4Attention", q: torch.Tensor, positions: torch.Tensor, compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1 diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 13e58360c8b..079471a41a1 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -29,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, ) @@ -52,10 +51,8 @@ from vllm.model_executor.models.utils import ( ) from vllm.model_executor.utils import set_weight_attrs from vllm.models.deepseek_v4.attention import ( - DeepseekV4Indexer, - DeepseekV4MLA, + DeepseekV4Attention, ) -from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.sequence import IntermediateTensors @@ -608,165 +605,6 @@ class DeepseekV4MoE(nn.Module): self.experts.finalize_weights() -class DeepseekV4Attention(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream_list: list[torch.cuda.Stream] | None = None, - ): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - layer_id = extract_layer_index(prefix) - - self.layer_id = layer_id - self.hidden_size = config.hidden_size - self.n_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - assert self.n_heads % tp_size == 0 - - self.n_local_heads = self.n_heads // tp_size - self.q_lora_rank = config.q_lora_rank - self.o_lora_rank = config.o_lora_rank - self.head_dim = config.head_dim - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = self.head_dim - self.rope_head_dim - self.n_groups = config.o_groups - self.n_local_groups = self.n_groups // tp_size - self.window_size = config.sliding_window - # NOTE(zyongye) Compress ratio can't be 0 - # we do this for because MTP layer is not included - # in the compress ratio list - if layer_id < config.num_hidden_layers: - self.compress_ratio = max(1, config.compress_ratios[layer_id]) - else: - self.compress_ratio = 1 - self.eps = config.rms_norm_eps - self.max_position_embeddings = config.max_position_embeddings - - # Padded to min 64 heads for FlashMLA, initialized to -inf - # (no sink effect). Weight loading fills the first n_local_heads slots. - padded_heads = max(self.n_local_heads, 64) - self.attn_sink = nn.Parameter( - torch.full((padded_heads,), -float("inf"), dtype=torch.float32), - requires_grad=False, - ) - - self.fused_wqa_wkv = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_wqa_wkv", - disable_tp=True, # fused ReplicatedLinear - ) - self.q_norm = RMSNorm(self.q_lora_rank, self.eps) - self.wq_b = ColumnParallelLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wq_b", - ) - - self.kv_norm = RMSNorm(self.head_dim, self.eps) - self.wo_a = ColumnParallelLinear( - self.n_heads * self.head_dim // self.n_groups, - self.n_groups * self.o_lora_rank, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_a", - ) - self.wo_a.is_bmm = True - self.wo_a.bmm_batch_size = self.n_local_groups - self.wo_b = RowParallelLinear( - self.n_groups * self.o_lora_rank, - self.hidden_size, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_b", - ) - self.softmax_scale = self.head_dim**-0.5 - self.scale_fmt = config.quantization_config["scale_fmt"] - - self.rope_parameters = config.rope_scaling - - # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) - self.rotary_emb = build_deepseek_v4_rope( - config, - head_dim=self.head_dim, - rope_head_dim=self.rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - compress_ratio=self.compress_ratio, - ) - - self.indexer = None - if self.compress_ratio == 4: - # Only C4A uses sparse attention and hence has indexer. - # aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is - # free here (outer GEMMs joined) for the inner overlap of - # wq_b+fused_indexer_q_rope_quant vs compressor. - indexer_aux_stream = ( - aux_stream_list[2] if aux_stream_list is not None else None - ) - self.indexer = DeepseekV4Indexer( - vllm_config, - config=config, - hidden_size=self.hidden_size, - q_lora_rank=self.q_lora_rank, - quant_config=quant_config, - cache_config=vllm_config.cache_config, - topk_indices_buffer=topk_indices_buffer, - compress_ratio=self.compress_ratio, - prefix=f"{prefix}.indexer", - aux_stream=indexer_aux_stream, - ) - - self.mla_attn = DeepseekV4MLA( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - vllm_config=vllm_config, - fused_wqa_wkv=self.fused_wqa_wkv, - q_norm=self.q_norm, - wq_b=self.wq_b, - kv_norm=self.kv_norm, - wo_a=self.wo_a, - wo_b=self.wo_b, - attn_sink=self.attn_sink, - rotary_emb=self.rotary_emb, - indexer=self.indexer, - indexer_rotary_emb=self.rotary_emb, - topk_indices_buffer=topk_indices_buffer, - aux_stream_list=aux_stream_list, - window_size=self.window_size, - compress_ratio=self.compress_ratio, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - prefix=prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - llama_4_scaling: torch.Tensor | None, - ): - return self.mla_attn(positions, hidden_states, llama_4_scaling) - - class DeepseekV4DecoderLayer(nn.Module): def __init__( self, @@ -938,7 +776,7 @@ class DeepseekV4Model(nn.Module): self.rms_norm_eps = config.rms_norm_eps # Three aux streams: one per non-default input GEMM in - # DeepseekV4MLA.attn_gemm_parallel_execute + # DeepseekV4Attention.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. aux_stream_list = [torch.cuda.Stream() for _ in range(3)] @@ -1236,7 +1074,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", }, orig_to_new_substr={ - ".attn.compressor.": ".attn.mla_attn.compressor.", ".shared_experts.w2": ".shared_experts.down_proj", }, ) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 9140a6fccd5..e3173949bf1 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad fp8_use_mixed_batch = ( self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4 ) - # DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not + # DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not # consume fp8_extra_metadata. Skipping the build here avoids a # forced D2H sync on seq_lens that would otherwise fire on every # prefill-bearing step, lifting GPU utilization on long-prefill