docs: fix MLA attention docstring examples (#44118)

Co-authored-by: nightcityblade <nightcityblade@gmail.com>
nightcityblade
2026-06-01 03:28:38 +08:00
committed by GitHub
parent 6bdabbad5b
commit 8b8546da1c
@@ -11,12 +11,12 @@ Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
-compute friendly approach, we generally want to use the compute friendly
-approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
-and the data-movement friendly approach for "decode" (i.e. the ratio
-Sq / Skv is "large").
+compute friendly approach. We generally want to use the compute friendly
+approach for "prefill" (i.e. the ratio Sq / Skv is relatively large, often near
+1) and the data-movement friendly approach for "decode" (i.e. the ratio
+Sq / Skv is small).
-NOTE what we deem small and large is currently determined by if its labelled
+NOTE what we deem small and large is currently determined by if it is labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
@@ -96,7 +96,7 @@ NOTE: in the actual code,
Runtime
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(-1, N, P)
-ql_nope = einsum("snh,lnh->snl", q, W_UK)
+ql_nope = einsum("snh,lnh->snl", q_nope, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
@@ -115,7 +115,7 @@ spda_o = scaled_dot_product_attention(
)
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
-return o.view(-1, N * V) @ self.num_heads @ W_O
+return o.view(-1, N * V) @ W_O
## Chunked Prefill
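
A quick shape check of the two corrected "Runtime" lines, written as a minimal NumPy sketch and not taken from the commit. The toy sizes and random weights are hypothetical. The RoPE terms, the softmax scale, and masking are left out for brevity, and the weight shapes (W_UK as [Lkv, N, P], W_UV as [Lkv, N, V], W_O as [N * V, H]) are assumptions inferred from the einsum subscripts in the docstring. It shows that feeding q_nope into the einsum gives the same attention scores as up-projecting the latent KV cache first, and that the final projection needs only W_O.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes, chosen only so every axis has a distinct length.
Sq, Skv, N, P, V, Lkv, H = 3, 5, 2, 4, 7, 6, 8

q_nope = rng.standard_normal((Sq, N, P))   # per-head queries (no RoPE part)
kv_c = rng.standard_normal((Skv, Lkv))     # latent KV cache
W_UK = rng.standard_normal((Lkv, N, P))    # assumed shape [Lkv, N, P]
W_UV = rng.standard_normal((Lkv, N, V))    # assumed shape [Lkv, N, V]
W_O = rng.standard_normal((N * V, H))      # assumed shape [N * V, H]

# Corrected docstring line: absorb W_UK into the query (q_nope, not q).
ql_nope = np.einsum("snh,lnh->snl", q_nope, W_UK)          # (Sq, N, Lkv)
scores_absorbed = np.einsum("snl,tl->nst", ql_nope, kv_c)  # (N, Sq, Skv)

# Reference: up-project the latent cache to per-head keys, then take q . k.
k_nope = np.einsum("tl,lnp->tnp", kv_c, W_UK)              # (Skv, N, P)
scores_explicit = np.einsum("snp,tnp->nst", q_nope, k_nope)

# Unscaled, unmasked softmax over the KV axis (a simplification).
w = np.exp(scores_absorbed - scores_absorbed.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)

# Attend in latent space, then apply W_UV and the output projection.
attn_latent = np.einsum("nst,tl->snl", w, kv_c)            # (Sq, N, Lkv)
o = np.einsum("snl,lnv->snv", attn_latent, W_UV)           # (Sq, N, V)
out = o.reshape(-1, N * V) @ W_O                           # (Sq, H)
```

In this sketch, `o.reshape(-1, N * V)` is already a `(Sq, N * V)` matrix, so a single matmul with `W_O` yields the `(Sq, H)` output, which is consistent with the commit dropping the `self.num_heads` operand.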