mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[DSV4] Remove unncessary classes & functions (#44246)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -48,8 +48,7 @@ from vllm.model_executor.models.utils import (
|
||||
)
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4Indexer,
|
||||
DeepseekV4MLAModules,
|
||||
DeepseekV4MultiHeadLatentAttentionWrapper,
|
||||
DeepseekV4MLA,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -314,7 +313,7 @@ class DeepseekV4Attention(nn.Module):
|
||||
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
|
||||
rope_parameters = config.rope_parameters
|
||||
rope_parameters["rope_theta"] = (
|
||||
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
|
||||
@@ -351,7 +350,17 @@ class DeepseekV4Attention(nn.Module):
|
||||
prefix=f"{prefix}.indexer",
|
||||
)
|
||||
|
||||
mla_modules = DeepseekV4MLAModules(
|
||||
self.mla_attn = DeepseekV4MLA(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
vllm_config=vllm_config,
|
||||
fused_wqa_wkv=self.fused_wqa_wkv,
|
||||
q_norm=self.q_norm,
|
||||
@@ -365,19 +374,6 @@ class DeepseekV4Attention(nn.Module):
|
||||
indexer_rotary_emb=self.rotary_emb,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
aux_stream_list=aux_stream_list,
|
||||
)
|
||||
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
mla_modules=mla_modules,
|
||||
window_size=self.window_size,
|
||||
compress_ratio=self.compress_ratio,
|
||||
cache_config=vllm_config.cache_config,
|
||||
@@ -618,7 +614,7 @@ class DeepseekV4Model(nn.Module):
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Three aux streams: one per non-default input GEMM in
|
||||
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
|
||||
# DeepseekV4MLA.attn_gemm_parallel_execute
|
||||
# (compressor kv_score, indexer.weights_proj, indexer.compressor
|
||||
# kv_score). fused_wqa_wkv stays on the default stream.
|
||||
# Disable them on ROCm because of hang issues.
|
||||
|
||||
@@ -5,7 +5,6 @@ DeepseekV4 MLA Attention Layer
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import torch
|
||||
@@ -38,9 +37,8 @@ from vllm.config import (
|
||||
get_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@@ -90,46 +88,7 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]":
|
||||
return DeepseekV4FlashMLASparseImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV4MLAModules:
|
||||
"""Modules used in DeepseekV4 MLA."""
|
||||
|
||||
vllm_config: VllmConfig
|
||||
fused_wqa_wkv: torch.nn.Module
|
||||
q_norm: torch.nn.Module
|
||||
wq_b: torch.nn.Module
|
||||
kv_norm: torch.nn.Module
|
||||
wo_a: torch.nn.Module
|
||||
wo_b: torch.nn.Module
|
||||
attn_sink: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
indexer: torch.nn.Module | None
|
||||
indexer_rotary_emb: torch.nn.Module
|
||||
topk_indices_buffer: torch.Tensor | None
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None
|
||||
|
||||
|
||||
# --8<-- [start:multi_head_latent_attention]
|
||||
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
|
||||
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
"""Pluggable MLA layer which allows OOT backends to add
|
||||
custom implementations of the outer MLA layer (including rope & o_proj).
|
||||
Note that currently oot platforms can still use CustomOp.register_oot to
|
||||
replace MLA layer entirely, although we use PluggableLayer to register
|
||||
this layer now.
|
||||
|
||||
This class takes positions and hidden_states as input.
|
||||
The input tensors can either contain prefill tokens or decode tokens.
|
||||
The class does the following:
|
||||
|
||||
1. MLA Preprocess.
|
||||
2. Perform multi-head attention to prefill tokens and
|
||||
multi-query attention to decode tokens separately.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
# --8<-- [end:multi_head_latent_attention]
|
||||
|
||||
class DeepseekV4MLA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -142,7 +101,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
o_lora_rank: int | None,
|
||||
mla_modules: DeepseekV4MLAModules,
|
||||
vllm_config: VllmConfig,
|
||||
fused_wqa_wkv: torch.nn.Module,
|
||||
q_norm: torch.nn.Module,
|
||||
wq_b: torch.nn.Module,
|
||||
kv_norm: torch.nn.Module,
|
||||
wo_a: torch.nn.Module,
|
||||
wo_b: torch.nn.Module,
|
||||
attn_sink: torch.nn.Module,
|
||||
rotary_emb: torch.nn.Module,
|
||||
indexer: torch.nn.Module | None,
|
||||
indexer_rotary_emb: torch.nn.Module,
|
||||
topk_indices_buffer: torch.Tensor | None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None,
|
||||
window_size: int,
|
||||
compress_ratio: int | None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
@@ -162,7 +133,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
self.prefix = prefix
|
||||
|
||||
# Extract config from vllm_config
|
||||
config = mla_modules.vllm_config.model_config.hf_config
|
||||
config = vllm_config.model_config.hf_config
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
|
||||
@@ -173,12 +144,12 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
|
||||
# Store projection modules
|
||||
self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
|
||||
self.q_norm = mla_modules.q_norm
|
||||
self.wq_b = mla_modules.wq_b
|
||||
self.fused_wqa_wkv = fused_wqa_wkv
|
||||
self.q_norm = q_norm
|
||||
self.wq_b = wq_b
|
||||
|
||||
self.kv_norm = mla_modules.kv_norm
|
||||
self.wo_a = mla_modules.wo_a
|
||||
self.kv_norm = kv_norm
|
||||
self.wo_a = wo_a
|
||||
|
||||
self._wo_a_act_quant = QuantFP8(
|
||||
static=False,
|
||||
@@ -188,7 +159,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
|
||||
# INT32) so fp8_einsum can handle layout transform internally.
|
||||
self._wo_a_act_quant.use_deep_gemm_supported = False
|
||||
self.wo_b = mla_modules.wo_b
|
||||
self.wo_b = wo_b
|
||||
|
||||
# Pick fp8_einsum recipe based on GPU arch:
|
||||
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
|
||||
@@ -198,11 +169,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
|
||||
self._tma_aligned_scales = cap.major >= 10
|
||||
|
||||
self.rotary_emb = mla_modules.rotary_emb
|
||||
self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
|
||||
self.topk_indices_buffer = mla_modules.topk_indices_buffer
|
||||
self.rotary_emb = rotary_emb
|
||||
self.indexer_rotary_emb = indexer_rotary_emb
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
|
||||
self.indexer = mla_modules.indexer
|
||||
self.indexer = indexer
|
||||
|
||||
# Per-head RMS normalization for Q (no learnable weights)
|
||||
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
|
||||
@@ -216,7 +187,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
|
||||
# Will be None on ROCm for now.
|
||||
self.aux_stream_list = mla_modules.aux_stream_list
|
||||
self.aux_stream_list = aux_stream_list
|
||||
# [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
|
||||
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
|
||||
# before post-GEMM starts.
|
||||
@@ -243,7 +214,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
window_size=self.window_size,
|
||||
head_bytes=head_bytes,
|
||||
swa_cache_layer=self.swa_cache_layer,
|
||||
attn_sink=mla_modules.attn_sink, # already padded with -inf
|
||||
attn_sink=attn_sink, # already padded with -inf
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
@@ -253,21 +224,12 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
# Mirror the inner layer's padded head count (single source of truth).
|
||||
self.padded_heads = self.mla_attn.padded_heads
|
||||
|
||||
# Register this layer in the compilation config's static forward context
|
||||
# This allows the custom op to retrieve the layer during execution
|
||||
compilation_config = mla_modules.vllm_config.compilation_config
|
||||
# HACK
|
||||
self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
|
||||
if self.layer_name in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {self.layer_name}")
|
||||
compilation_config.static_forward_context[self.layer_name] = self
|
||||
|
||||
# Create the compressor for layers with compress_ratio > 1; after
|
||||
# creating the DeepseekV4MLAAttention layer to get its cache.
|
||||
self.compressor = None
|
||||
if self.compress_ratio > 1:
|
||||
self.compressor = DeepseekCompressor(
|
||||
vllm_config=mla_modules.vllm_config,
|
||||
vllm_config=vllm_config,
|
||||
compress_ratio=self.compress_ratio,
|
||||
hidden_size=self.hidden_size,
|
||||
head_dim=self.head_dim,
|
||||
@@ -291,15 +253,10 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# @eager_break_during_capture: this is where the breakable
|
||||
# cudagraph capture breaks (the attention op runs eagerly between
|
||||
# captured graph segments).
|
||||
deepseek_v4_attention(
|
||||
hidden_states,
|
||||
positions,
|
||||
o_padded,
|
||||
self.layer_name,
|
||||
)
|
||||
# attention_impl is wrapped with @eager_break_during_capture: this is
|
||||
# where the breakable cudagraph capture breaks (the attention op runs
|
||||
# eagerly between captured graph segments).
|
||||
self.attention_impl(hidden_states, positions, o_padded)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# Keep ROCm on the BF16 reference wo_a path util kernel ready.
|
||||
@@ -405,6 +362,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
|
||||
return qr_kv, kv_score, indexer_kv_score, indexer_weights
|
||||
|
||||
@eager_break_during_capture
|
||||
def attention_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -541,18 +499,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
|
||||
|
||||
@eager_break_during_capture
|
||||
def deepseek_v4_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.attention_impl(hidden_states, positions, out)
|
||||
|
||||
|
||||
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -54,8 +54,7 @@ from vllm.model_executor.models.utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4Indexer,
|
||||
DeepseekV4MLAModules,
|
||||
DeepseekV4MultiHeadLatentAttentionWrapper,
|
||||
DeepseekV4MLA,
|
||||
)
|
||||
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -697,7 +696,7 @@ class DeepseekV4Attention(nn.Module):
|
||||
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
|
||||
rope_parameters = config.rope_parameters
|
||||
rope_parameters["rope_theta"] = (
|
||||
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
|
||||
@@ -741,7 +740,17 @@ class DeepseekV4Attention(nn.Module):
|
||||
aux_stream=indexer_aux_stream,
|
||||
)
|
||||
|
||||
mla_modules = DeepseekV4MLAModules(
|
||||
self.mla_attn = DeepseekV4MLA(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
vllm_config=vllm_config,
|
||||
fused_wqa_wkv=self.fused_wqa_wkv,
|
||||
q_norm=self.q_norm,
|
||||
@@ -755,19 +764,6 @@ class DeepseekV4Attention(nn.Module):
|
||||
indexer_rotary_emb=self.rotary_emb,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
aux_stream_list=aux_stream_list,
|
||||
)
|
||||
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
mla_modules=mla_modules,
|
||||
window_size=self.window_size,
|
||||
compress_ratio=self.compress_ratio,
|
||||
cache_config=vllm_config.cache_config,
|
||||
@@ -955,7 +951,7 @@ class DeepseekV4Model(nn.Module):
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Three aux streams: one per non-default input GEMM in
|
||||
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
|
||||
# DeepseekV4MLA.attn_gemm_parallel_execute
|
||||
# (compressor kv_score, indexer.weights_proj, indexer.compressor
|
||||
# kv_score). fused_wqa_wkv stays on the default stream.
|
||||
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
|
||||
|
||||
Reference in New Issue
Block a user