commit 5ec0ffa4bd
Author: Fanrong Li (committed by GitHub)
Date:   2026-01-13 19:17:08 +08:00
3 changed files with 21 additions and 37 deletions

View File

@@ -1328,7 +1328,6 @@ class Indexer(nn.Module):
     def sparse_attn_indexer(
         self,
         metadata: DSAtrtllmAttentionMetadata,
-        hidden_states: torch.Tensor,
         q_fp8: torch.Tensor,
         k_fp8: torch.Tensor,
         k_scale: torch.Tensor,
@@ -1344,12 +1343,11 @@ class Indexer(nn.Module):
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens
-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1
         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
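Note: the dropped hidden_states argument was only consulted for its leading dimension and device, both of which are already available from num_tokens and q_fp8. A minimal, self-contained sketch of the new allocation; the sizes here are illustrative, not taken from the diff:

import torch

num_tokens, index_topk = 8, 4        # illustrative sizes
q_fp8 = torch.zeros(num_tokens, 64)  # stand-in for the FP8 query tensor

# The buffer is sized from num_tokens and placed on q_fp8's device,
# so hidden_states is no longer needed in the signature.
topk_indices_buffer = torch.empty((num_tokens, index_topk),
                                  dtype=torch.int32,
                                  device=q_fp8.device)
topk_indices_buffer[:num_tokens] = -1  # fallback path initializes to -1
print(topk_indices_buffer.shape)       # torch.Size([8, 4])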
@@ -1525,11 +1523,10 @@ class Indexer(nn.Module):
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

-    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+    def _qk_projection_and_rope(self, qr: torch.Tensor, k: torch.Tensor,
                                 position_ids: torch.Tensor):
         """Project Q/K and apply RoPE"""
         q = self.wq_b(qr)
-        k = self.k_norm(indexer_k)
         q = q.view(-1, self.n_heads, self.head_dim)
         q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
@@ -1537,7 +1534,9 @@
                                dim=-1)
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
         k_pe = k_pe[:, 0, :]
-        return q_pe, q_nope, k_pe, k_nope
+        q = self._prep_q_or_k(q_pe, q_nope)
+        k = self._prep_q_or_k(k_pe, k_nope)
+        return q, k

     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
@@ -1562,14 +1561,7 @@
             self.ln_events[1],
             self.aux_stream,
         )
-        q_pe, q_nope, k_pe, k_nope = q_and_k
-        q, k = maybe_execute_in_parallel(
-            lambda: self._prep_q_or_k(q_pe, q_nope),
-            lambda: self._prep_q_or_k(k_pe, k_nope),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
+        q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
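Because prep now happens inside _qk_projection_and_rope, which already runs under maybe_execute_in_parallel, the second overlap call above is dead code and the pair unpacks directly. For reference, a simplified sketch of the overlap pattern such a helper implements; this is an assumption about its shape, not TensorRT-LLM's exact implementation:

import torch

def maybe_execute_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream=None):
    """Run fn_a on the current stream and fn_b on aux_stream, if available."""
    if aux_stream is None or not torch.cuda.is_available():
        return fn_a(), fn_b()  # serial fallback (e.g. on CPU)
    event_a.record()           # mark where the current stream's inputs are ready
    with torch.cuda.stream(aux_stream):
        event_a.wait()         # aux stream waits for those inputs
        out_b = fn_b()
        event_b.record()
    out_a = fn_a()             # launches on the main stream, overlaps with fn_b
    event_b.wait()             # main stream waits for fn_b before using out_b
    return out_a, out_b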
@@ -1585,8 +1577,8 @@
         )
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)

 class DSATrtllmAttention(TrtllmAttention):

View File

@@ -1354,7 +1354,6 @@ class MLA(nn.Module):
             self.indexer.head_dim
         ], -1)

-        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1363,9 +1362,14 @@
             self.aux_stream,
         )
         qr = q
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+        latent_cache, indexer_k = maybe_execute_in_parallel(
+            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
+            lambda: self.indexer.k_norm(indexer_k),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+        # TODO: fuse wq_b + (indexer) wlq here
         q = self.q_b_proj(q)

         # Indexer
         topk_indices = self.indexer(
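This MLA change moves indexer.k_norm out of _qk_projection_and_rope and overlaps it with the latent-cache concat, which is safe because the two branches touch disjoint tensors. A serial, CPU-runnable sketch of what the parallel pair computes; shapes and the RMSNorm stand-in are illustrative (torch.nn.RMSNorm needs PyTorch >= 2.4):

import torch

compressed_kv = torch.randn(8, 512)  # illustrative shapes
k_pe = torch.randn(8, 64)
indexer_k = torch.randn(8, 128)
k_norm = torch.nn.RMSNorm(128)       # stand-in for indexer.k_norm

# Serial equivalent of the two lambdas handed to maybe_execute_in_parallel:
latent_cache = torch.cat([compressed_kv, k_pe], dim=-1)
indexer_k = k_norm(indexer_k)
print(latent_cache.shape, indexer_k.shape)  # (8, 576) and (8, 128)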

View File

@@ -1559,8 +1559,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
               f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")
         indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
-        topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
-                                                           hidden_states, q_fp8,
+        topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked, q_fp8,
                                                            k_fp8, k_scale, weights)

     print(f"✓ Chunked execution completed, shape: {topk_indices_chunked.shape}")
@@ -1592,8 +1591,8 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
     indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
-                                                        hidden_states, q_fp8,
-                                                        k_fp8, k_scale, weights)
+                                                        q_fp8, k_fp8, k_scale,
+                                                        weights)

     print(
         f"✓ Non-chunked execution completed, shape: {topk_indices_baseline.shape}"
@@ -1866,7 +1865,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -1892,7 +1890,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -1921,7 +1918,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     try:
         topk_indices_skip = indexer.sparse_attn_indexer(
             metadata_skip,
-            hidden_states,
             q_fp8,
             k_fp8,
             k_scale,
@@ -2035,7 +2031,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2055,7 +2050,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2143,7 +2137,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2166,7 +2159,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
     metadata_fallback.indexer_prefill_chunks = None
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2191,7 +2183,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
     try:
         topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2302,7 +2293,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
     # Test custom kernel
     topk_custom = indexer.sparse_attn_indexer(metadata,
-                                              hidden_states,
                                               q_fp8,
                                               k_fp8,
                                               k_scale,
@@ -2311,7 +2301,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
     # Test fallback
     topk_fallback = indexer.sparse_attn_indexer(metadata,
-                                                hidden_states,
                                                 q_fp8,
                                                 k_fp8,
                                                 k_scale,
@@ -2337,7 +2326,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
     Indexer.prepare(metadata_skip)
     indexer._update_k_cache(k_fp8, k_scale, metadata_skip)
     topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                    hidden_states,
                                                     q_fp8,
                                                     k_fp8,
                                                     k_scale,