[DSV4] Move more ops out of eager breakpoint (#44561)

This commit is contained in:
Woosuk Kwon
2026-06-05 06:42:41 -07:00
committed by GitHub
parent bbb6c274c8
commit 02d2da0748
+30 -14
View File
@@ -330,10 +330,34 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
device=hidden_states.device,
)
# Metadata-independent input GEMMs + RMSNorm stay in the captured
# graph; the metadata-dependent rest (q up-proj + kv-insert, indexer,
# compressor, MLA attention) runs in the eager break.
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
self.attn_gemm_parallel_execute(hidden_states)
)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
# attention_impl is wrapped with @eager_break_during_capture: this is
# where the breakable cudagraph capture breaks (the attention op runs
# eagerly between captured graph segments).
self.attention_impl(hidden_states, positions, o_padded)
self.attention_impl(
hidden_states,
qr,
kv,
kv_score,
indexer_kv_score,
indexer_weights,
positions,
o_padded,
)
o = o_padded[:, : self.n_local_heads, :]
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
@@ -403,25 +427,17 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
def attention_impl(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
kv_score: torch.Tensor,
indexer_kv_score: torch.Tensor,
indexer_weights: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
self.attn_gemm_parallel_execute(hidden_states)
)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
# on the default stream so q stays on its consumer stream (forward_mqa
# downstream reads q on default). Indexer/compressor go on aux for