diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 361473b4d1..f88a16adf1 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1539,9 +1539,6 @@ class Indexer(nn.Module):
         k = self._prep_q_or_k(k_pe, k_nope)
         return q, k
 
-    def _weight_proj(self, hidden_states: torch.Tensor):
-        return self.weights_proj(_to_float(hidden_states))
-
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
         q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
@@ -1552,12 +1549,19 @@ class Indexer(nn.Module):
         return q_or_k
 
     @torch.inference_mode()
-    def forward(self, q_and_k: torch.Tensor, weights: torch.Tensor,
+    def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor,
                 indexer_k: torch.Tensor):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
 
+        q_and_k, weights = maybe_execute_in_parallel(
+            lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
+            lambda: self.weights_proj(_to_float(hidden_states)),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
         q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index ed3113d5e1..2d68926ed4 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1346,14 +1346,6 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: The output tensor.
         """
-
-        def _qk_projection_and_rope(q: torch.Tensor, k: torch.Tensor,
-                                    position_ids: torch.Tensor):
-            q_b_proj = self.q_b_proj(q)
-            indexer_q_and_k = self.indexer._qk_projection_and_rope(
-                q, k, position_ids)
-            return q_b_proj, indexer_q_and_k
-
         assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode"
         # split q, k, v into context and gen batches
         num_contexts = attn_metadata.num_contexts
@@ -1378,6 +1370,7 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
+        qr = q
         latent_cache, indexer_k = maybe_execute_in_parallel(
             lambda: torch.concat([compressed_kv, k_pe], dim=-1),
             lambda: self.indexer.k_norm(indexer_k),
@@ -1385,19 +1378,12 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        q_and_k, indexer_weights = maybe_execute_in_parallel(
-            lambda: _qk_projection_and_rope(q, indexer_k, position_ids),
-            lambda: self.indexer._weight_proj(hidden_states),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
-        q, indxer_q_and_k = q_and_k
+        q = self.q_b_proj(q)
 
         # Indexer
         topk_indices = self.indexer(
-            indxer_q_and_k,
-            indexer_weights,
+            qr,
+            hidden_states,
             attn_metadata,
             position_ids,
             indexer_k=indexer_k,  # indexer K proj
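
The diff moves the weights projection out of the inline `_qk_projection_and_rope` helper in `MLA` and into `Indexer.forward`, where it is overlapped with the Q/K projection + RoPE on the auxiliary stream via `maybe_execute_in_parallel`. For reference, below is a minimal, self-contained sketch of that two-stream overlap pattern; `run_in_parallel`, the event/stream names, and the dummy workloads are illustrative assumptions, not the actual TensorRT-LLM helper.

```python
# Illustrative sketch of the overlap pattern behind maybe_execute_in_parallel
# (not the real tensorrt_llm implementation).
import torch


def run_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream):
    """Run fn_a on the current stream and fn_b on aux_stream, then join."""
    # Fence: aux_stream must not start fn_b before prior main-stream work is done.
    event_a.record()
    result_a = fn_a()  # e.g. the Q/K projection + RoPE
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)
        result_b = fn_b()  # e.g. weights_proj(hidden_states), overlapped with fn_a
        event_b.record()
    # Main stream waits until the aux-stream work finishes before using result_b.
    torch.cuda.current_stream().wait_event(event_b)
    return result_a, result_b


if torch.cuda.is_available():
    aux_stream = torch.cuda.Stream()
    ev0, ev1 = torch.cuda.Event(), torch.cuda.Event()
    x = torch.randn(1024, 1024, device="cuda")
    a, b = run_in_parallel(lambda: x @ x, lambda: x.float().sum(), ev0, ev1, aux_stream)
```

With the weights projection launched inside `Indexer.forward` this way, the separate `_weight_proj` helper and the extra `maybe_execute_in_parallel` call site in `MLA` become unnecessary, so the caller only forwards `qr` and `hidden_states` to the indexer.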