mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[DSv4] Refactor DeepseekV4Attention
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@@ -46,10 +45,8 @@ from vllm.model_executor.models.utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4Indexer,
|
||||
DeepseekV4MLA,
|
||||
DeepseekV4Attention,
|
||||
)
|
||||
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
@@ -225,158 +222,6 @@ class DeepseekV4MoE(nn.Module):
|
||||
return final_hidden_states.view(org_shape)
|
||||
|
||||
|
||||
class DeepseekV4Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
layer_id = extract_layer_index(prefix)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.n_heads % tp_size == 0
|
||||
|
||||
self.n_local_heads = self.n_heads // tp_size
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.n_groups = config.o_groups
|
||||
self.n_local_groups = self.n_groups // tp_size
|
||||
self.window_size = config.sliding_window
|
||||
# NOTE(zyongye) Compress ratio can't be 0
|
||||
# we do this because the MTP layer is not included
|
||||
# in the compress ratio list
|
||||
if layer_id < config.num_hidden_layers:
|
||||
self.compress_ratio = max(1, config.compress_ratios[layer_id])
|
||||
else:
|
||||
self.compress_ratio = 1
|
||||
self.eps = config.rms_norm_eps
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# Padded to min 64 heads for FlashMLA, initialized to -inf
|
||||
# (no sink effect). Weight loading fills the first n_local_heads slots.
|
||||
padded_heads = max(self.n_local_heads, 64)
|
||||
self.attn_sink = nn.Parameter(
|
||||
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.fused_wqa_wkv = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_wqa_wkv",
|
||||
disable_tp=True, # fused ReplicatedLinear
|
||||
)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_a",
|
||||
)
|
||||
self.wo_a.is_bmm = True
|
||||
self.wo_a.bmm_batch_size = self.n_local_groups
|
||||
self.wo_b = RowParallelLinear(
|
||||
self.n_groups * self.o_lora_rank,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_b",
|
||||
)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
self.scale_fmt = config.quantization_config["scale_fmt"]
|
||||
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
|
||||
self.rotary_emb = build_deepseek_v4_rope(
|
||||
config,
|
||||
head_dim=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
|
||||
self.indexer = None
|
||||
if self.compress_ratio == 4:
|
||||
# Only C4A uses sparse attention and hence has indexer.
|
||||
self.indexer = DeepseekV4Indexer(
|
||||
vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
quant_config=quant_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
compress_ratio=self.compress_ratio,
|
||||
prefix=f"{prefix}.indexer",
|
||||
)
|
||||
|
||||
self.mla_attn = DeepseekV4MLA(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
vllm_config=vllm_config,
|
||||
fused_wqa_wkv=self.fused_wqa_wkv,
|
||||
q_norm=self.q_norm,
|
||||
wq_b=self.wq_b,
|
||||
kv_norm=self.kv_norm,
|
||||
wo_a=self.wo_a,
|
||||
wo_b=self.wo_b,
|
||||
attn_sink=self.attn_sink,
|
||||
rotary_emb=self.rotary_emb,
|
||||
indexer=self.indexer,
|
||||
indexer_rotary_emb=self.rotary_emb,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
aux_stream_list=aux_stream_list,
|
||||
window_size=self.window_size,
|
||||
compress_ratio=self.compress_ratio,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None,
|
||||
):
|
||||
return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
||||
|
||||
|
||||
class DeepseekV4DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -601,7 +446,7 @@ class DeepseekV4Model(nn.Module):
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Three aux streams: one per non-default input GEMM in
|
||||
# DeepseekV4MLA.attn_gemm_parallel_execute
|
||||
# DeepseekV4Attention.attn_gemm_parallel_execute
|
||||
# (compressor kv_score, indexer.weights_proj, indexer.compressor
|
||||
# kv_score). fused_wqa_wkv stays on the default stream.
|
||||
# Disable them on ROCm because of hang issues.
|
||||
@@ -897,7 +742,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
".shared_experts.w2": ".shared_experts.down_proj",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4MLAAttention,
|
||||
DeepseekV4Attention,
|
||||
)
|
||||
|
||||
|
||||
@@ -599,7 +599,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
@classmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -677,7 +677,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
@classmethod
|
||||
def _forward_decode(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor | None,
|
||||
swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata,
|
||||
@@ -740,7 +740,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
@classmethod
|
||||
def _forward_prefill(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
compressed_k_cache: torch.Tensor | None,
|
||||
|
||||
@@ -15,7 +15,10 @@ from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
|
||||
from vllm.models.deepseek_v4.common.ops import (
|
||||
@@ -42,12 +45,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
|
||||
from vllm.models.deepseek_v4.compressor import DeepseekCompressor
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.multi_stream_utils import (
|
||||
@@ -88,78 +87,89 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]":
|
||||
return DeepseekV4FlashMLASparseImpl
|
||||
|
||||
|
||||
class DeepseekV4MLA(nn.Module):
|
||||
class DeepseekV4Attention(nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
o_lora_rank: int | None,
|
||||
vllm_config: VllmConfig,
|
||||
fused_wqa_wkv: torch.nn.Module,
|
||||
q_norm: torch.nn.Module,
|
||||
wq_b: torch.nn.Module,
|
||||
kv_norm: torch.nn.Module,
|
||||
wo_a: torch.nn.Module,
|
||||
wo_b: torch.nn.Module,
|
||||
attn_sink: torch.nn.Module,
|
||||
rotary_emb: torch.nn.Module,
|
||||
indexer: torch.nn.Module | None,
|
||||
indexer_rotary_emb: torch.nn.Module,
|
||||
topk_indices_buffer: torch.Tensor | None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None,
|
||||
window_size: int,
|
||||
compress_ratio: int | None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
prefix: str,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.n_local_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.window_size = window_size
|
||||
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
|
||||
self.prefix = prefix
|
||||
|
||||
# Extract config from vllm_config
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
layer_id = extract_layer_index(prefix)
|
||||
|
||||
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
|
||||
self.eps = config.rms_norm_eps
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = head_dim - self.rope_head_dim
|
||||
self.n_local_groups = config.o_groups // tp_size
|
||||
self.prefix = prefix # Alias for compatibility with compressor
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_heads = config.num_attention_heads
|
||||
assert self.n_heads % tp_size == 0
|
||||
self.n_local_heads = self.n_heads // tp_size
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.n_groups = config.o_groups
|
||||
self.n_local_groups = self.n_groups // tp_size
|
||||
self.window_size = config.sliding_window
|
||||
# NOTE(zyongye) Compress ratio can't be 0
|
||||
# we do this because the MTP layer is not included
|
||||
# in the compress ratio list
|
||||
if layer_id < config.num_hidden_layers:
|
||||
self.compress_ratio = max(1, config.compress_ratios[layer_id])
|
||||
else:
|
||||
self.compress_ratio = 1
|
||||
self.eps = config.rms_norm_eps
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
# Store projection modules
|
||||
self.fused_wqa_wkv = fused_wqa_wkv
|
||||
self.q_norm = q_norm
|
||||
self.wq_b = wq_b
|
||||
|
||||
self.kv_norm = kv_norm
|
||||
self.wo_a = wo_a
|
||||
|
||||
self._wo_a_act_quant = QuantFP8(
|
||||
static=False,
|
||||
group_shape=GroupShape(1, 128),
|
||||
use_ue8m0=True,
|
||||
# Padded to min 64 heads for FlashMLA, initialized to -inf
|
||||
# (no sink effect). Weight loading fills the first n_local_heads slots.
|
||||
padded_heads = max(self.n_local_heads, 64)
|
||||
self.attn_sink = nn.Parameter(
|
||||
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.fused_wqa_wkv = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_wqa_wkv",
|
||||
disable_tp=True, # fused ReplicatedLinear
|
||||
)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_a",
|
||||
)
|
||||
self.wo_a.is_bmm = True
|
||||
self.wo_a.bmm_batch_size = self.n_local_groups
|
||||
self.wo_b = RowParallelLinear(
|
||||
self.n_groups * self.o_lora_rank,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_b",
|
||||
)
|
||||
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
|
||||
# INT32) so fp8_einsum can handle layout transform internally.
|
||||
self._wo_a_act_quant.use_deep_gemm_supported = False
|
||||
self.wo_b = wo_b
|
||||
|
||||
# Pick fp8_einsum recipe based on GPU arch:
|
||||
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
|
||||
@@ -169,22 +179,38 @@ class DeepseekV4MLA(nn.Module):
|
||||
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
|
||||
self._tma_aligned_scales = cap.major >= 10
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.indexer_rotary_emb = indexer_rotary_emb
|
||||
# Initialize rotary embedding before the indexer/compressor consume it.
|
||||
self.rotary_emb = build_deepseek_v4_rope(
|
||||
config,
|
||||
head_dim=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
self.indexer_rotary_emb = self.rotary_emb
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
|
||||
self.indexer = indexer
|
||||
|
||||
# Per-head RMS normalization for Q (no learnable weights)
|
||||
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
|
||||
|
||||
# TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
|
||||
head_bytes = (
|
||||
self.nope_head_dim # 448 fp8 NoPE
|
||||
+ self.rope_head_dim * 2 # 64 bf16 RoPE
|
||||
+ self.nope_head_dim // 64 # 7B scale factors
|
||||
+ 1 # 1B pad
|
||||
)
|
||||
self.indexer = None
|
||||
if self.compress_ratio == 4:
|
||||
# Only C4A uses sparse attention and hence has indexer.
|
||||
# aux_stream_list[2] is free here (outer GEMMs joined) for the inner
|
||||
# overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on
|
||||
# ROCm, where aux_stream_list is None.
|
||||
indexer_aux_stream = (
|
||||
aux_stream_list[2] if aux_stream_list is not None else None
|
||||
)
|
||||
self.indexer = DeepseekV4Indexer(
|
||||
vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
compress_ratio=self.compress_ratio,
|
||||
prefix=f"{prefix}.indexer",
|
||||
aux_stream=indexer_aux_stream,
|
||||
)
|
||||
|
||||
# Will be None on ROCm for now.
|
||||
self.aux_stream_list = aux_stream_list
|
||||
@@ -202,30 +228,48 @@ class DeepseekV4MLA(nn.Module):
|
||||
cache_config=cache_config,
|
||||
)
|
||||
|
||||
self.mla_attn = DeepseekV4MLAAttention(
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
compress_ratio=self.compress_ratio,
|
||||
window_size=self.window_size,
|
||||
head_bytes=head_bytes,
|
||||
swa_cache_layer=self.swa_cache_layer,
|
||||
attn_sink=attn_sink, # already padded with -inf
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
indexer=self.indexer,
|
||||
topk_indices_buffer=self.topk_indices_buffer,
|
||||
)
|
||||
# Mirror the inner layer's padded head count (single source of truth).
|
||||
self.padded_heads = self.mla_attn.padded_heads
|
||||
# ---- Attention layer setup (formerly DeepseekV4MLAAttention) ----
|
||||
self.impl_cls = _select_v4_sparse_impl()
|
||||
self.backend_cls = self.impl_cls.backend_cls
|
||||
# Padded Q head count is dictated by the selected impl.
|
||||
self.padded_heads = self.impl_cls.get_padded_num_q_heads(self.n_local_heads)
|
||||
|
||||
# Create the compressor for layers with compress_ratio > 1; after
|
||||
# creating the DeepseekV4MLAAttention layer to get its cache.
|
||||
self.max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
# DeepseekV4 only supports fp8 kv-cache format for now.
|
||||
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
|
||||
assert kv_cache_dtype.startswith("fp8"), (
|
||||
f"DeepseekV4 only supports fp8 kv-cache format for now, "
|
||||
f"got {kv_cache_dtype}"
|
||||
)
|
||||
assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
|
||||
"Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
|
||||
)
|
||||
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
|
||||
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
|
||||
if (
|
||||
issubclass(self.get_attn_backend(), FlashMLASparseBackend)
|
||||
and kv_cache_dtype.startswith("fp8")
|
||||
and kv_cache_dtype != "fp8_ds_mla"
|
||||
):
|
||||
assert cache_config is not None
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# Register with compilation context for metadata lookup
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix and prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
if prefix:
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.kv_cache = torch.tensor([])
|
||||
|
||||
# Create the compressor for layers with compress_ratio > 1; after the
|
||||
# attention-layer setup above so its KV cache prefix is available.
|
||||
self.compressor = None
|
||||
if self.compress_ratio > 1:
|
||||
self.compressor = DeepseekCompressor(
|
||||
@@ -235,7 +279,7 @@ class DeepseekV4MLA(nn.Module):
|
||||
head_dim=self.head_dim,
|
||||
rotate=True,
|
||||
prefix=f"{prefix}.compressor",
|
||||
k_cache_prefix=self.mla_attn.prefix,
|
||||
k_cache_prefix=self.prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -449,7 +493,7 @@ class DeepseekV4MLA(nn.Module):
|
||||
|
||||
# MLA attention writes into the pre-allocated `out` buffer
|
||||
# ([num_tokens, padded_heads, head_dim]).
|
||||
self.mla_attn(q, kv, positions, output=out)
|
||||
self.impl_cls.forward_mqa(self, q, kv, positions, out)
|
||||
|
||||
def _fused_qnorm_rope_kv_insert(
|
||||
self,
|
||||
@@ -498,102 +542,6 @@ class DeepseekV4MLA(nn.Module):
|
||||
swa_metadata.block_size,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
compress_ratio: int,
|
||||
window_size: int,
|
||||
head_bytes: int,
|
||||
swa_cache_layer: DeepseekV4SWACache,
|
||||
attn_sink: torch.Tensor,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
# Sparse MLA Args
|
||||
indexer: object | None = None,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream: torch.cuda.Stream | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.impl_cls = _select_v4_sparse_impl()
|
||||
self.backend_cls = self.impl_cls.backend_cls
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = 1
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
self.window_size = window_size
|
||||
self.head_bytes = head_bytes
|
||||
self.compress_ratio = compress_ratio
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.nope_head_dim = qk_nope_head_dim
|
||||
self.rope_head_dim = qk_rope_head_dim
|
||||
self.indexer = indexer
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
|
||||
self.prefix = prefix # Alias for compatibility with compressor
|
||||
|
||||
self.aux_stream = aux_stream
|
||||
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
|
||||
|
||||
# Padded Q head count is dictated by the selected impl.
|
||||
self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads)
|
||||
|
||||
# Store attention sink
|
||||
assert attn_sink is not None
|
||||
self.attn_sink: torch.Tensor = attn_sink
|
||||
# Store SWA cache
|
||||
assert swa_cache_layer is not None
|
||||
self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer
|
||||
|
||||
# Get vllm config for cache setup
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
# DeepseekV4 only supports fp8 kv-cache format for now.
|
||||
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
|
||||
|
||||
assert kv_cache_dtype.startswith("fp8"), (
|
||||
f"DeepseekV4 only supports fp8 kv-cache format for now, "
|
||||
f"got {kv_cache_dtype}"
|
||||
)
|
||||
assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
|
||||
"Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
|
||||
)
|
||||
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
|
||||
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
|
||||
if (
|
||||
issubclass(self.get_attn_backend(), FlashMLASparseBackend)
|
||||
and kv_cache_dtype.startswith("fp8")
|
||||
and kv_cache_dtype != "fp8_ds_mla"
|
||||
):
|
||||
assert cache_config is not None
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# Register with compilation context for metadata lookup
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix and prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
if prefix:
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
self.kv_cache = torch.tensor([])
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.backend_cls
|
||||
|
||||
@@ -613,15 +561,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
model_version="deepseek_v4",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
self.impl_cls.forward_mqa(self, q, kv, positions, output)
|
||||
|
||||
|
||||
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
@@ -778,8 +717,6 @@ class DeepseekV4Indexer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
rotary_emb: nn.Module,
|
||||
) -> torch.Tensor:
|
||||
compressor = self.compressor
|
||||
|
||||
def wq_b_and_q_quant():
|
||||
# ReplicatedLinear returns (output, bias); bias is None.
|
||||
q, _ = self.wq_b(qr)
|
||||
@@ -798,7 +735,7 @@ class DeepseekV4Indexer(nn.Module):
|
||||
# join orders that write before indexer_op (skip_k_cache_insert=True).
|
||||
(q_quant, weights), k = maybe_execute_in_parallel(
|
||||
wq_b_and_q_quant,
|
||||
lambda: compressor(compressed_kv_score, positions, rotary_emb),
|
||||
lambda: self.compressor(compressed_kv_score, positions, rotary_emb),
|
||||
self.ln_events[0],
|
||||
self.ln_events[1],
|
||||
self.aux_stream,
|
||||
|
||||
@@ -29,7 +29,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4MLAAttention,
|
||||
DeepseekV4Attention,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
|
||||
|
||||
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
||||
class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
"""Abstract parent for DeepseekV4 sparse MLA impls.
|
||||
|
||||
V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``)
|
||||
V4 sparse MLA is driven by the layer (``DeepseekV4Attention.forward``)
|
||||
rather than the v1 framework, so ``forward_mqa`` is overridden with a
|
||||
classmethod that takes the layer as its first argument. This Liskov-broken
|
||||
override is intentional: the grandparent's instance-method ``forward_mqa``
|
||||
@@ -55,7 +55,7 @@ class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMeta
|
||||
@abstractmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -129,7 +129,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
@classmethod
|
||||
def forward_mqa( # type: ignore[override]
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -210,7 +210,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
@classmethod
|
||||
def _forward_decode(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1
|
||||
swa_metadata: "DeepseekSparseSWAMetadata",
|
||||
@@ -304,7 +304,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
|
||||
@classmethod
|
||||
def _forward_prefill(
|
||||
cls,
|
||||
layer: "DeepseekV4MLAAttention",
|
||||
layer: "DeepseekV4Attention",
|
||||
q: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1
|
||||
|
||||
@@ -29,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
||||
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@@ -52,10 +51,8 @@ from vllm.model_executor.models.utils import (
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.models.deepseek_v4.attention import (
|
||||
DeepseekV4Indexer,
|
||||
DeepseekV4MLA,
|
||||
DeepseekV4Attention,
|
||||
)
|
||||
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
|
||||
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -608,165 +605,6 @@ class DeepseekV4MoE(nn.Module):
|
||||
self.experts.finalize_weights()
|
||||
|
||||
|
||||
class DeepseekV4Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer: torch.Tensor | None = None,
|
||||
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
layer_id = extract_layer_index(prefix)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.n_heads % tp_size == 0
|
||||
|
||||
self.n_local_heads = self.n_heads // tp_size
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.o_lora_rank = config.o_lora_rank
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.n_groups = config.o_groups
|
||||
self.n_local_groups = self.n_groups // tp_size
|
||||
self.window_size = config.sliding_window
|
||||
# NOTE(zyongye) Compress ratio can't be 0
|
||||
# we do this because the MTP layer is not included
|
||||
# in the compress ratio list
|
||||
if layer_id < config.num_hidden_layers:
|
||||
self.compress_ratio = max(1, config.compress_ratios[layer_id])
|
||||
else:
|
||||
self.compress_ratio = 1
|
||||
self.eps = config.rms_norm_eps
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# Padded to min 64 heads for FlashMLA, initialized to -inf
|
||||
# (no sink effect). Weight loading fills the first n_local_heads slots.
|
||||
padded_heads = max(self.n_local_heads, 64)
|
||||
self.attn_sink = nn.Parameter(
|
||||
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.fused_wqa_wkv = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_wqa_wkv",
|
||||
disable_tp=True, # fused ReplicatedLinear
|
||||
)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_a",
|
||||
)
|
||||
self.wo_a.is_bmm = True
|
||||
self.wo_a.bmm_batch_size = self.n_local_groups
|
||||
self.wo_b = RowParallelLinear(
|
||||
self.n_groups * self.o_lora_rank,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.wo_b",
|
||||
)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
self.scale_fmt = config.quantization_config["scale_fmt"]
|
||||
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
|
||||
self.rotary_emb = build_deepseek_v4_rope(
|
||||
config,
|
||||
head_dim=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
|
||||
self.indexer = None
|
||||
if self.compress_ratio == 4:
|
||||
# Only C4A uses sparse attention and hence has indexer.
|
||||
# aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is
|
||||
# free here (outer GEMMs joined) for the inner overlap of
|
||||
# wq_b+fused_indexer_q_rope_quant vs compressor.
|
||||
indexer_aux_stream = (
|
||||
aux_stream_list[2] if aux_stream_list is not None else None
|
||||
)
|
||||
self.indexer = DeepseekV4Indexer(
|
||||
vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
quant_config=quant_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
compress_ratio=self.compress_ratio,
|
||||
prefix=f"{prefix}.indexer",
|
||||
aux_stream=indexer_aux_stream,
|
||||
)
|
||||
|
||||
self.mla_attn = DeepseekV4MLA(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.n_local_heads,
|
||||
head_dim=self.head_dim,
|
||||
scale=self.softmax_scale,
|
||||
qk_nope_head_dim=self.nope_head_dim,
|
||||
qk_rope_head_dim=self.rope_head_dim,
|
||||
v_head_dim=self.head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.head_dim,
|
||||
o_lora_rank=self.o_lora_rank,
|
||||
vllm_config=vllm_config,
|
||||
fused_wqa_wkv=self.fused_wqa_wkv,
|
||||
q_norm=self.q_norm,
|
||||
wq_b=self.wq_b,
|
||||
kv_norm=self.kv_norm,
|
||||
wo_a=self.wo_a,
|
||||
wo_b=self.wo_b,
|
||||
attn_sink=self.attn_sink,
|
||||
rotary_emb=self.rotary_emb,
|
||||
indexer=self.indexer,
|
||||
indexer_rotary_emb=self.rotary_emb,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
aux_stream_list=aux_stream_list,
|
||||
window_size=self.window_size,
|
||||
compress_ratio=self.compress_ratio,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None,
|
||||
):
|
||||
return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
||||
|
||||
|
||||
class DeepseekV4DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -938,7 +776,7 @@ class DeepseekV4Model(nn.Module):
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Three aux streams: one per non-default input GEMM in
|
||||
# DeepseekV4MLA.attn_gemm_parallel_execute
|
||||
# DeepseekV4Attention.attn_gemm_parallel_execute
|
||||
# (compressor kv_score, indexer.weights_proj, indexer.compressor
|
||||
# kv_score). fused_wqa_wkv stays on the default stream.
|
||||
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
|
||||
@@ -1236,7 +1074,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
".shared_experts.w2": ".shared_experts.down_proj",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
fp8_use_mixed_batch = (
|
||||
self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
|
||||
)
|
||||
# DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
|
||||
# DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not
|
||||
# consume fp8_extra_metadata. Skipping the build here avoids a
|
||||
# forced D2H sync on seq_lens that would otherwise fire on every
|
||||
# prefill-bearing step, lifting GPU utilization on long-prefill
|
||||
|
||||
Reference in New Issue
Block a user