mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[DSV4] Move more ops out of eager breakpoint (#44561)
This commit is contained in:
@@ -330,10 +330,34 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# Metadata-independent input GEMMs + RMSNorm stay in the captured
|
||||
# graph; the metadata-dependent rest (q up-proj + kv-insert, indexer,
|
||||
# compressor, MLA attention) runs in the eager break.
|
||||
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
|
||||
self.attn_gemm_parallel_execute(hidden_states)
|
||||
)
|
||||
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
|
||||
qr, kv = fused_q_kv_rmsnorm(
|
||||
qr,
|
||||
kv,
|
||||
self.q_norm.weight.data,
|
||||
self.kv_norm.weight.data,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
# attention_impl is wrapped with @eager_break_during_capture: this is
|
||||
# where the breakable cudagraph capture breaks (the attention op runs
|
||||
# eagerly between captured graph segments).
|
||||
self.attention_impl(hidden_states, positions, o_padded)
|
||||
self.attention_impl(
|
||||
hidden_states,
|
||||
qr,
|
||||
kv,
|
||||
kv_score,
|
||||
indexer_kv_score,
|
||||
indexer_weights,
|
||||
positions,
|
||||
o_padded,
|
||||
)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
|
||||
@@ -403,25 +427,17 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
|
||||
def attention_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
kv_score: torch.Tensor,
|
||||
indexer_kv_score: torch.Tensor,
|
||||
indexer_weights: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
|
||||
) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
|
||||
self.attn_gemm_parallel_execute(hidden_states)
|
||||
)
|
||||
|
||||
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
|
||||
qr, kv = fused_q_kv_rmsnorm(
|
||||
qr,
|
||||
kv,
|
||||
self.q_norm.weight.data,
|
||||
self.kv_norm.weight.data,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
|
||||
# on the default stream so q stays on its consumer stream (forward_mqa
|
||||
# downstream reads q on default). Indexer/compressor go on aux for
|
||||
|
||||
Reference in New Issue
Block a user