diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 0bcad59624..361473b4d1 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1535,7 +1535,9 @@ class Indexer(nn.Module):
                                          dim=-1)
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
         k_pe = k_pe[:, 0, :]
-        return q_pe, q_nope, k_pe, k_nope
+        q = self._prep_q_or_k(q_pe, q_nope)
+        k = self._prep_q_or_k(k_pe, k_nope)
+        return q, k
 
     def _weight_proj(self, hidden_states: torch.Tensor):
         return self.weights_proj(_to_float(hidden_states))
@@ -1556,14 +1558,7 @@ class Indexer(nn.Module):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
 
-        q_pe, q_nope, k_pe, k_nope = q_and_k
-        q, k = maybe_execute_in_parallel(
-            lambda: self._prep_q_or_k(q_pe, q_nope),
-            lambda: self._prep_q_or_k(k_pe, k_nope),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
+        q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)