From 8c3cc98cffd31b910c41b11076e8c175fc6dabe9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 1 Jun 2026 14:43:00 -0700 Subject: [PATCH] [DSV4] Remove unncessary classes & functions (#44246) Signed-off-by: Woosuk Kwon --- vllm/models/deepseek_v4/amd/model.py | 32 +++---- vllm/models/deepseek_v4/attention.py | 122 +++++++----------------- vllm/models/deepseek_v4/nvidia/model.py | 32 +++---- 3 files changed, 62 insertions(+), 124 deletions(-) diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 28836a2b143..885fffea868 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -48,8 +48,7 @@ from vllm.model_executor.models.utils import ( ) from vllm.models.deepseek_v4.attention import ( DeepseekV4Indexer, - DeepseekV4MLAModules, - DeepseekV4MultiHeadLatentAttentionWrapper, + DeepseekV4MLA, ) from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -314,7 +313,7 @@ class DeepseekV4Attention(nn.Module): self.rope_parameters = config.rope_scaling - # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) + # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) rope_parameters = config.rope_parameters rope_parameters["rope_theta"] = ( config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta @@ -351,7 +350,17 @@ class DeepseekV4Attention(nn.Module): prefix=f"{prefix}.indexer", ) - mla_modules = DeepseekV4MLAModules( + self.mla_attn = DeepseekV4MLA( + hidden_size=self.hidden_size, + num_heads=self.n_local_heads, + head_dim=self.head_dim, + scale=self.softmax_scale, + qk_nope_head_dim=self.nope_head_dim, + qk_rope_head_dim=self.rope_head_dim, + v_head_dim=self.head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.head_dim, + o_lora_rank=self.o_lora_rank, vllm_config=vllm_config, fused_wqa_wkv=self.fused_wqa_wkv, q_norm=self.q_norm, @@ -365,19 +374,6 @@ class DeepseekV4Attention(nn.Module): indexer_rotary_emb=self.rotary_emb, topk_indices_buffer=topk_indices_buffer, aux_stream_list=aux_stream_list, - ) - self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - mla_modules=mla_modules, window_size=self.window_size, compress_ratio=self.compress_ratio, cache_config=vllm_config.cache_config, @@ -618,7 +614,7 @@ class DeepseekV4Model(nn.Module): self.rms_norm_eps = config.rms_norm_eps # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute + # DeepseekV4MLA.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. # Disable them on ROCm because of hang issues. diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 4fae5dc0529..55cb3d94ba6 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -5,7 +5,6 @@ DeepseekV4 MLA Attention Layer """ from collections.abc import Callable -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast import torch @@ -38,9 +37,8 @@ from vllm.config import ( get_current_vllm_config, ) from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import ForwardContext, get_forward_context +from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig @@ -90,46 +88,7 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": return DeepseekV4FlashMLASparseImpl -@dataclass -class DeepseekV4MLAModules: - """Modules used in DeepseekV4 MLA.""" - - vllm_config: VllmConfig - fused_wqa_wkv: torch.nn.Module - q_norm: torch.nn.Module - wq_b: torch.nn.Module - kv_norm: torch.nn.Module - wo_a: torch.nn.Module - wo_b: torch.nn.Module - attn_sink: torch.nn.Module - rotary_emb: torch.nn.Module - indexer: torch.nn.Module | None - indexer_rotary_emb: torch.nn.Module - topk_indices_buffer: torch.Tensor | None - aux_stream_list: list[torch.cuda.Stream] | None = None - - -# --8<-- [start:multi_head_latent_attention] -@PluggableLayer.register("deepseek_v4_multi_head_latent_attention") -class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): - """Pluggable MLA layer which allows OOT backends to add - custom implementations of the outer MLA layer (including rope & o_proj). - Note that currently oot platforms can still use CustomOp.register_oot to - replace MLA layer entirely, although we use PluggableLayer to register - this layer now. - - This class takes positions and hidden_states as input. - The input tensors can either contain prefill tokens or decode tokens. - The class does the following: - - 1. MLA Preprocess. - 2. Perform multi-head attention to prefill tokens and - multi-query attention to decode tokens separately. - 3. Return the output tensor. - """ - - # --8<-- [end:multi_head_latent_attention] - +class DeepseekV4MLA(nn.Module): def __init__( self, hidden_size: int, @@ -142,7 +101,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): q_lora_rank: int | None, kv_lora_rank: int, o_lora_rank: int | None, - mla_modules: DeepseekV4MLAModules, + vllm_config: VllmConfig, + fused_wqa_wkv: torch.nn.Module, + q_norm: torch.nn.Module, + wq_b: torch.nn.Module, + kv_norm: torch.nn.Module, + wo_a: torch.nn.Module, + wo_b: torch.nn.Module, + attn_sink: torch.nn.Module, + rotary_emb: torch.nn.Module, + indexer: torch.nn.Module | None, + indexer_rotary_emb: torch.nn.Module, + topk_indices_buffer: torch.Tensor | None, + aux_stream_list: list[torch.cuda.Stream] | None, window_size: int, compress_ratio: int | None, cache_config: CacheConfig | None = None, @@ -162,7 +133,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self.prefix = prefix # Extract config from vllm_config - config = mla_modules.vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config tp_size = get_tensor_model_parallel_world_size() # DeepseekV4-specific attributes (num_heads is already TP-adjusted) @@ -173,12 +144,12 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self.o_lora_rank = config.o_lora_rank # Store projection modules - self.fused_wqa_wkv = mla_modules.fused_wqa_wkv - self.q_norm = mla_modules.q_norm - self.wq_b = mla_modules.wq_b + self.fused_wqa_wkv = fused_wqa_wkv + self.q_norm = q_norm + self.wq_b = wq_b - self.kv_norm = mla_modules.kv_norm - self.wo_a = mla_modules.wo_a + self.kv_norm = kv_norm + self.wo_a = wo_a self._wo_a_act_quant = QuantFP8( static=False, @@ -188,7 +159,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): # Bypass packed-for-deepgemm path — we need FP32 scales (not packed # INT32) so fp8_einsum can handle layout transform internally. self._wo_a_act_quant.use_deep_gemm_supported = False - self.wo_b = mla_modules.wo_b + self.wo_b = wo_b # Pick fp8_einsum recipe based on GPU arch: # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 @@ -198,11 +169,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) self._tma_aligned_scales = cap.major >= 10 - self.rotary_emb = mla_modules.rotary_emb - self.indexer_rotary_emb = mla_modules.indexer_rotary_emb - self.topk_indices_buffer = mla_modules.topk_indices_buffer + self.rotary_emb = rotary_emb + self.indexer_rotary_emb = indexer_rotary_emb + self.topk_indices_buffer = topk_indices_buffer - self.indexer = mla_modules.indexer + self.indexer = indexer # Per-head RMS normalization for Q (no learnable weights) self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) @@ -216,7 +187,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) # Will be None on ROCm for now. - self.aux_stream_list = mla_modules.aux_stream_list + self.aux_stream_list = aux_stream_list # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins # before post-GEMM starts. @@ -243,7 +214,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): window_size=self.window_size, head_bytes=head_bytes, swa_cache_layer=self.swa_cache_layer, - attn_sink=mla_modules.attn_sink, # already padded with -inf + attn_sink=attn_sink, # already padded with -inf cache_config=cache_config, quant_config=quant_config, prefix=prefix, @@ -253,21 +224,12 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): # Mirror the inner layer's padded head count (single source of truth). self.padded_heads = self.mla_attn.padded_heads - # Register this layer in the compilation config's static forward context - # This allows the custom op to retrieve the layer during execution - compilation_config = mla_modules.vllm_config.compilation_config - # HACK - self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention" - if self.layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {self.layer_name}") - compilation_config.static_forward_context[self.layer_name] = self - # Create the compressor for layers with compress_ratio > 1; after # creating the DeepseekV4MLAAttention layer to get its cache. self.compressor = None if self.compress_ratio > 1: self.compressor = DeepseekCompressor( - vllm_config=mla_modules.vllm_config, + vllm_config=vllm_config, compress_ratio=self.compress_ratio, hidden_size=self.hidden_size, head_dim=self.head_dim, @@ -291,15 +253,10 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): device=hidden_states.device, ) - # @eager_break_during_capture: this is where the breakable - # cudagraph capture breaks (the attention op runs eagerly between - # captured graph segments). - deepseek_v4_attention( - hidden_states, - positions, - o_padded, - self.layer_name, - ) + # attention_impl is wrapped with @eager_break_during_capture: this is + # where the breakable cudagraph capture breaks (the attention op runs + # eagerly between captured graph segments). + self.attention_impl(hidden_states, positions, o_padded) o = o_padded[:, : self.n_local_heads, :] # Keep ROCm on the BF16 reference wo_a path util kernel ready. @@ -405,6 +362,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): return qr_kv, kv_score, indexer_kv_score, indexer_weights + @eager_break_during_capture def attention_impl( self, hidden_states: torch.Tensor, @@ -541,18 +499,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) -@eager_break_during_capture -def deepseek_v4_attention( - hidden_states: torch.Tensor, - positions: torch.Tensor, - out: torch.Tensor, - layer_name: str, -) -> None: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - self.attention_impl(hidden_states, positions, out) - - class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): def __init__( self, diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 30a7e6e747f..e26d9e593be 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -54,8 +54,7 @@ from vllm.model_executor.models.utils import ( from vllm.model_executor.utils import set_weight_attrs from vllm.models.deepseek_v4.attention import ( DeepseekV4Indexer, - DeepseekV4MLAModules, - DeepseekV4MultiHeadLatentAttentionWrapper, + DeepseekV4MLA, ) from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.sequence import IntermediateTensors @@ -697,7 +696,7 @@ class DeepseekV4Attention(nn.Module): self.rope_parameters = config.rope_scaling - # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) + # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) rope_parameters = config.rope_parameters rope_parameters["rope_theta"] = ( config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta @@ -741,7 +740,17 @@ class DeepseekV4Attention(nn.Module): aux_stream=indexer_aux_stream, ) - mla_modules = DeepseekV4MLAModules( + self.mla_attn = DeepseekV4MLA( + hidden_size=self.hidden_size, + num_heads=self.n_local_heads, + head_dim=self.head_dim, + scale=self.softmax_scale, + qk_nope_head_dim=self.nope_head_dim, + qk_rope_head_dim=self.rope_head_dim, + v_head_dim=self.head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.head_dim, + o_lora_rank=self.o_lora_rank, vllm_config=vllm_config, fused_wqa_wkv=self.fused_wqa_wkv, q_norm=self.q_norm, @@ -755,19 +764,6 @@ class DeepseekV4Attention(nn.Module): indexer_rotary_emb=self.rotary_emb, topk_indices_buffer=topk_indices_buffer, aux_stream_list=aux_stream_list, - ) - self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - mla_modules=mla_modules, window_size=self.window_size, compress_ratio=self.compress_ratio, cache_config=vllm_config.cache_config, @@ -955,7 +951,7 @@ class DeepseekV4Model(nn.Module): self.rms_norm_eps = config.rms_norm_eps # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute + # DeepseekV4MLA.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. aux_stream_list = [torch.cuda.Stream() for _ in range(3)]