Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
commit 152d46b566
parent 3519a1a685

fix.

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
@@ -1329,7 +1329,6 @@ class Indexer(nn.Module):
     def sparse_attn_indexer(
         self,
         metadata: DSAtrtllmAttentionMetadata,
-        hidden_states: torch.Tensor,
         q_fp8: torch.Tensor,
         k_fp8: torch.Tensor,
         k_scale: torch.Tensor,
@@ -1345,12 +1344,11 @@ class Indexer(nn.Module):
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens
 
-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1
 
         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
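Note: the allocation rewrite above relies on hidden_states being the flattened [num_tokens, hidden_dim] batch in this code path, so num_tokens from the attention metadata matches hidden_states.shape[0], and q_fp8 lives on the same device. A minimal sketch of that equivalence (sizes below are illustrative placeholders, not values from the commit):

import torch

# Hypothetical sizes for illustration only; the real values come from the
# DSA attention metadata at runtime.
num_ctx_tokens, num_gen_tokens = 6, 2
num_tokens = num_ctx_tokens + num_gen_tokens
hidden_dim, index_topk = 16, 4

hidden_states = torch.randn(num_tokens, hidden_dim)
q_fp8 = torch.randn(num_tokens, hidden_dim)  # stand-in; the real tensor is FP8

# Old allocation, keyed off hidden_states:
old_buf = torch.empty((hidden_states.shape[0], index_topk),
                      dtype=torch.int32,
                      device=hidden_states.device)

# New allocation, keyed off the metadata token count and q_fp8; same shape
# and device as long as hidden_states is the flattened batch described above.
new_buf = torch.empty((num_tokens, index_topk),
                      dtype=torch.int32,
                      device=q_fp8.device)

assert old_buf.shape == new_buf.shape and old_buf.device == new_buf.device

With the buffer no longer derived from hidden_states, the parameter can be dropped from sparse_attn_indexer entirely, which is what the signature and call-site hunks do.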
@@ -1581,8 +1579,8 @@ class Indexer(nn.Module):
         )
 
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)
 
 
 class DSATrtllmAttention(TrtllmAttention):
@@ -1392,9 +1392,7 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        q, indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope = q_and_k
-        indxer_q_and_k = (indexer_q_pe, indexer_q_nope, indexer_k_pe,
-                          indexer_k_nope)
+        q, indxer_q_and_k = q_and_k
 
         # Indexer
         topk_indices = self.indexer(
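Note: the MLA hunk is a pure call-site simplification. It assumes the producer of q_and_k (the fused projection that runs under the ln_events guards) now returns the four indexer tensors already packed as a tuple, so the caller unpacks two values instead of unpacking five and re-tupling four. A self-contained sketch with a hypothetical stand-in producer (make_q_and_k and its shapes are not from the commit):

from typing import Tuple

import torch

def make_q_and_k() -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
    # Hypothetical stand-in for MLA's fused projection; the real code runs
    # on an aux CUDA stream synchronized via ln_events.
    q = torch.randn(2, 8)
    # (q_pe, q_nope, k_pe, k_nope) packed at the producer side.
    indexer_parts = tuple(torch.randn(2, 4) for _ in range(4))
    return q, indexer_parts

q, indxer_q_and_k = make_q_and_k()  # single unpack after the change
indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope = indxer_q_and_k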