diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 4bdf6acc079..29302584880 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -330,10 +330,34 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC): device=hidden_states.device, ) + # Metadata-independent input GEMMs + RMSNorm stay in the captured + # graph; the metadata-dependent rest (q up-proj + kv-insert, indexer, + # compressor, MLA attention) runs in the eager break. + qr_kv, kv_score, indexer_kv_score, indexer_weights = ( + self.attn_gemm_parallel_execute(hidden_states) + ) + qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) + qr, kv = fused_q_kv_rmsnorm( + qr, + kv, + self.q_norm.weight.data, + self.kv_norm.weight.data, + self.eps, + ) + # attention_impl is wrapped with @eager_break_during_capture: this is # where the breakable cudagraph capture breaks (the attention op runs # eagerly between captured graph segments). - self.attention_impl(hidden_states, positions, o_padded) + self.attention_impl( + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + o_padded, + ) o = o_padded[:, : self.n_local_heads, :] # Inverse-RoPE + wo_a + wo_b output projection (platform-specific). @@ -403,25 +427,17 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC): def attention_impl( self, hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + kv_score: torch.Tensor, + indexer_kv_score: torch.Tensor, + indexer_weights: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) - ) - - qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - qr, kv = fused_q_kv_rmsnorm( - qr, - kv, - self.q_norm.weight.data, - self.kv_norm.weight.data, - self.eps, - ) - # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride # on the default stream so q stays on its consumer stream (forward_mqa # downstream reads q on default). Indexer/compressor go on aux for