[DSV4] Remove unncessary classes & functions (#44246)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-06-01 14:43:00 -07:00
committed by GitHub
parent e4cbc4385d
commit 8c3cc98cff
3 changed files with 62 additions and 124 deletions
+14 -18
View File
@@ -48,8 +48,7 @@ from vllm.model_executor.models.utils import (
)
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
DeepseekV4MLA,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -314,7 +313,7 @@ class DeepseekV4Attention(nn.Module):
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
@@ -351,7 +350,17 @@ class DeepseekV4Attention(nn.Module):
prefix=f"{prefix}.indexer",
)
mla_modules = DeepseekV4MLAModules(
self.mla_attn = DeepseekV4MLA(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
@@ -365,19 +374,6 @@ class DeepseekV4Attention(nn.Module):
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
)
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
mla_modules=mla_modules,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
@@ -618,7 +614,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
# DeepseekV4MLA.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
# Disable them on ROCm because of hang issues.
+34 -88
View File
@@ -5,7 +5,6 @@ DeepseekV4 MLA Attention Layer
"""
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
import torch
@@ -38,9 +37,8 @@ from vllm.config import (
get_current_vllm_config,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -90,46 +88,7 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]":
return DeepseekV4FlashMLASparseImpl
@dataclass
class DeepseekV4MLAModules:
"""Modules used in DeepseekV4 MLA."""
vllm_config: VllmConfig
fused_wqa_wkv: torch.nn.Module
q_norm: torch.nn.Module
wq_b: torch.nn.Module
kv_norm: torch.nn.Module
wo_a: torch.nn.Module
wo_b: torch.nn.Module
attn_sink: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module | None
indexer_rotary_emb: torch.nn.Module
topk_indices_buffer: torch.Tensor | None
aux_stream_list: list[torch.cuda.Stream] | None = None
# --8<-- [start:multi_head_latent_attention]
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently oot platforms can still use CustomOp.register_oot to
replace MLA layer entirely, although we use PluggableLayer to register
this layer now.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
"""
# --8<-- [end:multi_head_latent_attention]
class DeepseekV4MLA(nn.Module):
def __init__(
self,
hidden_size: int,
@@ -142,7 +101,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
q_lora_rank: int | None,
kv_lora_rank: int,
o_lora_rank: int | None,
mla_modules: DeepseekV4MLAModules,
vllm_config: VllmConfig,
fused_wqa_wkv: torch.nn.Module,
q_norm: torch.nn.Module,
wq_b: torch.nn.Module,
kv_norm: torch.nn.Module,
wo_a: torch.nn.Module,
wo_b: torch.nn.Module,
attn_sink: torch.nn.Module,
rotary_emb: torch.nn.Module,
indexer: torch.nn.Module | None,
indexer_rotary_emb: torch.nn.Module,
topk_indices_buffer: torch.Tensor | None,
aux_stream_list: list[torch.cuda.Stream] | None,
window_size: int,
compress_ratio: int | None,
cache_config: CacheConfig | None = None,
@@ -162,7 +133,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
self.prefix = prefix
# Extract config from vllm_config
config = mla_modules.vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_config
tp_size = get_tensor_model_parallel_world_size()
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
@@ -173,12 +144,12 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
self.o_lora_rank = config.o_lora_rank
# Store projection modules
self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
self.q_norm = mla_modules.q_norm
self.wq_b = mla_modules.wq_b
self.fused_wqa_wkv = fused_wqa_wkv
self.q_norm = q_norm
self.wq_b = wq_b
self.kv_norm = mla_modules.kv_norm
self.wo_a = mla_modules.wo_a
self.kv_norm = kv_norm
self.wo_a = wo_a
self._wo_a_act_quant = QuantFP8(
static=False,
@@ -188,7 +159,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
# INT32) so fp8_einsum can handle layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = mla_modules.wo_b
self.wo_b = wo_b
# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
@@ -198,11 +169,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
self._tma_aligned_scales = cap.major >= 10
self.rotary_emb = mla_modules.rotary_emb
self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
self.topk_indices_buffer = mla_modules.topk_indices_buffer
self.rotary_emb = rotary_emb
self.indexer_rotary_emb = indexer_rotary_emb
self.topk_indices_buffer = topk_indices_buffer
self.indexer = mla_modules.indexer
self.indexer = indexer
# Per-head RMS normalization for Q (no learnable weights)
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
@@ -216,7 +187,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
# Will be None on ROCm for now.
self.aux_stream_list = mla_modules.aux_stream_list
self.aux_stream_list = aux_stream_list
# [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
# before post-GEMM starts.
@@ -243,7 +214,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
window_size=self.window_size,
head_bytes=head_bytes,
swa_cache_layer=self.swa_cache_layer,
attn_sink=mla_modules.attn_sink, # already padded with -inf
attn_sink=attn_sink, # already padded with -inf
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
@@ -253,21 +224,12 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# Mirror the inner layer's padded head count (single source of truth).
self.padded_heads = self.mla_attn.padded_heads
# Register this layer in the compilation config's static forward context
# This allows the custom op to retrieve the layer during execution
compilation_config = mla_modules.vllm_config.compilation_config
# HACK
self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
if self.layer_name in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {self.layer_name}")
compilation_config.static_forward_context[self.layer_name] = self
# Create the compressor for layers with compress_ratio > 1; after
# creating the DeepseekV4MLAAttention layer to get its cache.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
vllm_config=mla_modules.vllm_config,
vllm_config=vllm_config,
compress_ratio=self.compress_ratio,
hidden_size=self.hidden_size,
head_dim=self.head_dim,
@@ -291,15 +253,10 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
device=hidden_states.device,
)
# @eager_break_during_capture: this is where the breakable
# cudagraph capture breaks (the attention op runs eagerly between
# captured graph segments).
deepseek_v4_attention(
hidden_states,
positions,
o_padded,
self.layer_name,
)
# attention_impl is wrapped with @eager_break_during_capture: this is
# where the breakable cudagraph capture breaks (the attention op runs
# eagerly between captured graph segments).
self.attention_impl(hidden_states, positions, o_padded)
o = o_padded[:, : self.n_local_heads, :]
# Keep ROCm on the BF16 reference wo_a path util kernel ready.
@@ -405,6 +362,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
return qr_kv, kv_score, indexer_kv_score, indexer_weights
@eager_break_during_capture
def attention_impl(
self,
hidden_states: torch.Tensor,
@@ -541,18 +499,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
@eager_break_during_capture
def deepseek_v4_attention(
hidden_states: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.attention_impl(hidden_states, positions, out)
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
def __init__(
self,
+14 -18
View File
@@ -54,8 +54,7 @@ from vllm.model_executor.models.utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
DeepseekV4MLA,
)
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.sequence import IntermediateTensors
@@ -697,7 +696,7 @@ class DeepseekV4Attention(nn.Module):
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
@@ -741,7 +740,17 @@ class DeepseekV4Attention(nn.Module):
aux_stream=indexer_aux_stream,
)
mla_modules = DeepseekV4MLAModules(
self.mla_attn = DeepseekV4MLA(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
@@ -755,19 +764,6 @@ class DeepseekV4Attention(nn.Module):
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
)
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
mla_modules=mla_modules,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
@@ -955,7 +951,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
# DeepseekV4MLA.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]