From 3519a1a685e72afca3307525b0fb8f2f46c13f48 Mon Sep 17 00:00:00 2001
From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Date: Mon, 5 Jan 2026 03:02:53 -0800
Subject: [PATCH] Optimize multi-stream execution for DSA

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
---
 .../_torch/attention_backend/sparse/dsa.py | 15 +++-----
 tensorrt_llm/_torch/modules/attention.py   | 34 +++++++++++++++----
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index aa32d6317e..71460f8b74 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1526,11 +1526,10 @@ class Indexer(nn.Module):
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights
 
-    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+    def _qk_projection_and_rope(self, qr: torch.Tensor, k: torch.Tensor,
                                 position_ids: torch.Tensor):
         """Project Q/K and apply RoPE"""
         q = self.wq_b(qr)
-        k = self.k_norm(indexer_k)
         q = q.view(-1, self.n_heads, self.head_dim)
         q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
@@ -1540,6 +1539,9 @@ class Indexer(nn.Module):
         k_pe = k_pe[:, 0, :]
         return q_pe, q_nope, k_pe, k_nope
 
+    def _weight_proj(self, hidden_states: torch.Tensor):
+        return self.weights_proj(_to_float(hidden_states))
+
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
         q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
@@ -1550,19 +1552,12 @@ class Indexer(nn.Module):
         return q_or_k
 
     @torch.inference_mode()
-    def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
+    def forward(self, q_and_k: torch.Tensor, weights: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor,
                 indexer_k: torch.Tensor):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
-        q_and_k, weights = maybe_execute_in_parallel(
-            lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
-            lambda: self.weights_proj(_to_float(hidden_states)),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
         q_pe, q_nope, k_pe, k_nope = q_and_k
         q, k = maybe_execute_in_parallel(
             lambda: self._prep_q_or_k(q_pe, q_nope),
             lambda: self._prep_q_or_k(k_pe, k_nope),
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 47e793e481..c68cc28af5 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1346,6 +1346,14 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: The output tensor.
         """
+
+        def _qk_projection_and_rope(q: torch.Tensor, k: torch.Tensor,
+                                    position_ids: torch.Tensor):
+            q_b_proj = self.q_b_proj(q)
+            indexer_q_and_k = self.indexer._qk_projection_and_rope(
+                q, k, position_ids)
+            return q_b_proj, *indexer_q_and_k
+
         assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode"
         # split q, k, v into context and gen batches
         num_contexts = attn_metadata.num_contexts
@@ -1363,7 +1371,6 @@ class MLA(nn.Module):
                 self.indexer.head_dim
             ], -1)
 
-        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1371,15 +1378,28 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        qr = q
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+        latent_cache, indexer_k = maybe_execute_in_parallel(
+            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
+            lambda: self.indexer.k_norm(indexer_k),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+        q_and_k, indexer_weights = maybe_execute_in_parallel(
+            lambda: _qk_projection_and_rope(q, indexer_k, position_ids),
+            lambda: self.indexer._weight_proj(hidden_states),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+        q, indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope = q_and_k
+        indexer_q_and_k = (indexer_q_pe, indexer_q_nope, indexer_k_pe,
+                           indexer_k_nope)
 
-        # TODO: fuse wq_b + (indexer) wlq here
-        q = self.q_b_proj(q)
         # Indexer
         topk_indices = self.indexer(
-            qr,
-            hidden_states,
+            indexer_q_and_k,
+            indexer_weights,
             attn_metadata,
             position_ids,
             indexer_k=indexer_k,  # indexer K proj
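
Note on the pattern this patch reorganizes: maybe_execute_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream) is used here to overlap two independent pieces of work, one on the current CUDA stream and one on an auxiliary stream, synchronized with events. The restructuring hoists the stream orchestration out of Indexer.forward into MLA so that the indexer k_norm, the latent-cache concat, the q_b_proj/RoPE path, and the indexer weight projection can each be paired off across the two streams. Below is a minimal, self-contained sketch of that two-stream overlap pattern, not the TensorRT-LLM helper itself; the function name overlap_on_aux_stream, its sequential fallback, and the toy workloads are illustrative assumptions.

import torch


def overlap_on_aux_stream(fn_a, fn_b, event_a, event_b, aux_stream):
    """Run fn_a on the current stream while fn_b runs on aux_stream."""
    if aux_stream is None or not torch.cuda.is_available():
        # No auxiliary stream available: fall back to sequential execution.
        return fn_a(), fn_b()

    main_stream = torch.cuda.current_stream()
    # Make aux_stream wait for all work already queued on the main stream.
    event_a.record(main_stream)
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)
        out_b = fn_b()  # e.g. the indexer weight projection
        event_b.record(aux_stream)
    out_a = fn_a()  # e.g. q_b_proj + indexer RoPE on the main stream
    # The main stream must not consume out_b before aux_stream finishes.
    main_stream.wait_event(event_b)
    return out_a, out_b


if __name__ == "__main__" and torch.cuda.is_available():
    aux = torch.cuda.Stream()
    ev_a, ev_b = torch.cuda.Event(), torch.cuda.Event()
    x = torch.randn(1024, 1024, device="cuda")
    a, b = overlap_on_aux_stream(lambda: x @ x, lambda: x + 1.0, ev_a, ev_b, aux)
    torch.cuda.synchronize()

The key constraint the patch relies on is that the two callables handed to a single overlap call are independent of each other; that is why k_norm is computed in an earlier overlap pair before the q/k projection and weight projection are overlapped.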