Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Fanrong Li 2026-01-05 04:26:44 -08:00
parent 3519a1a685
commit 152d46b566
2 changed files with 7 additions and 11 deletions


@@ -1329,7 +1329,6 @@ class Indexer(nn.Module):
     def sparse_attn_indexer(
         self,
         metadata: DSAtrtllmAttentionMetadata,
-        hidden_states: torch.Tensor,
         q_fp8: torch.Tensor,
         k_fp8: torch.Tensor,
         k_scale: torch.Tensor,
@@ -1345,12 +1344,11 @@
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens
-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1
         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
@@ -1581,8 +1579,8 @@ class Indexer(nn.Module):
         )
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)


 class DSATrtllmAttention(TrtllmAttention):
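
For context, a minimal standalone sketch of the invariant this cleanup relies on (sizes below are made up): q_fp8 carries one row per token, so num_tokens and q_fp8.device provide exactly what was previously read from hidden_states.shape[0] and hidden_states.device, which is why the unused hidden_states argument can drop out of sparse_attn_indexer.

    import torch

    num_tokens, index_topk = 8, 4        # hypothetical sizes
    q_fp8 = torch.zeros(num_tokens, 64)  # stand-in for the per-token FP8 query

    # Same buffer as before the change, now derived without hidden_states:
    # one row of top-k indices per token, pre-filled with the -1 sentinel
    # when no custom top-k kernel will overwrite it.
    topk_indices_buffer = torch.empty((num_tokens, index_topk),
                                      dtype=torch.int32,
                                      device=q_fp8.device)
    topk_indices_buffer[:num_tokens] = -1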


@@ -1392,9 +1392,7 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        q, indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope = q_and_k
-        indxer_q_and_k = (indexer_q_pe, indexer_q_nope, indexer_k_pe,
-                          indexer_k_nope)
+        q, indxer_q_and_k = q_and_k

         # Indexer
         topk_indices = self.indexer(
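
The MLA-side edit assumes the fused projection helper run on the aux stream now returns the indexer tensors already bundled. A hedged sketch of that contract (the helper name, shapes, and tensor contents are illustrative, not the real API):

    import torch

    def fused_q_and_k():
        # Hypothetical stand-in for the aux-stream helper; real projections elided.
        q = torch.zeros(8, 64)
        indexer_q_pe, indexer_q_nope = torch.zeros(8, 64), torch.zeros(8, 64)
        indexer_k_pe, indexer_k_nope = torch.zeros(8, 64), torch.zeros(8, 64)
        # Bundle at the producer so callers can forward the tuple as-is.
        return q, (indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope)

    q, indxer_q_and_k = fused_q_and_k()  # two-way unpack replaces the old five-way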