diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 71fd297a7ed..140e071c746 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -11,12 +11,12 @@ Sq as Q sequence length Skv as KV sequence length MLA has two possible ways of computing, a data-movement friendly approach and a -compute friendly approach, we generally want to use the compute friendly -approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) -and the data-movement friendly approach for "decode" (i.e. the ratio -Sq / Skv is "large"). +compute friendly approach. We generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is relatively large, often near +1) and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is small). -NOTE what we deem small and large is currently determined by if its labelled +NOTE: what we deem small and large is currently determined by whether it is labelled prefill or decode by the scheduler, but this is something we should probably tune. @@ -96,7 +96,7 @@ NOTE: in the actual code, Runtime q_c = h_t @ W_DQ q_nope = (q_c @ W_UQ).view(-1, N, P) -ql_nope = einsum("snh,lnh->snl", q, W_UK) +ql_nope = einsum("snh,lnh->snl", q_nope, W_UK) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) @@ -115,7 +115,7 @@ spda_o = scaled_dot_product_attention( ) o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) -return o.view(-1, N * V) @ self.num_heads @ W_O +return o.view(-1, N * V) @ W_O ## Chunked Prefill