diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 71460f8b74..0bcad59624 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1329,7 +1329,6 @@ class Indexer(nn.Module):
     def sparse_attn_indexer(
         self,
         metadata: DSAtrtllmAttentionMetadata,
-        hidden_states: torch.Tensor,
         q_fp8: torch.Tensor,
         k_fp8: torch.Tensor,
         k_scale: torch.Tensor,
@@ -1345,12 +1344,11 @@ class Indexer(nn.Module):
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens

-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1

         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
@@ -1581,8 +1579,8 @@ class Indexer(nn.Module):
         )

         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)


 class DSATrtllmAttention(TrtllmAttention):
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index c68cc28af5..ed3113d5e1 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1392,9 +1392,7 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        q, indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope = q_and_k
-        indxer_q_and_k = (indexer_q_pe, indexer_q_nope, indexer_k_pe,
-                          indexer_k_nope)
+        q, indxer_q_and_k = q_and_k

         # Indexer
         topk_indices = self.indexer(