Merge 384e9e5b01 into 38296a472b

commit 5ec0ffa4bd
@@ -1328,7 +1328,6 @@ class Indexer(nn.Module):
     def sparse_attn_indexer(
         self,
         metadata: DSAtrtllmAttentionMetadata,
-        hidden_states: torch.Tensor,
         q_fp8: torch.Tensor,
         k_fp8: torch.Tensor,
         k_scale: torch.Tensor,
@@ -1344,12 +1343,11 @@ class Indexer(nn.Module):
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens

-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1

         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
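The replacement allocation derives both the row count and the device from arguments the indexer keeps, so `hidden_states` can drop out of the signature entirely. A minimal sketch of the new path (the sizes below are placeholders, not the model's real configuration):

import torch

# Placeholder sizes for illustration; the real values come from the model config.
num_tokens, index_topk = 16, 2048
q_fp8 = torch.zeros(num_tokens, 64, 128)  # stand-in for the FP8 query tensor

# New allocation: sized by num_tokens and placed on q_fp8's device.
topk_indices_buffer = torch.empty((num_tokens, index_topk),
                                  dtype=torch.int32,
                                  device=q_fp8.device)

# The fallback top-k path pre-fills the buffer with -1 ("no index selected").
use_custom_topk = False
if not use_custom_topk:
    topk_indices_buffer[:num_tokens] = -1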
@@ -1525,11 +1523,10 @@ class Indexer(nn.Module):
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

-    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+    def _qk_projection_and_rope(self, qr: torch.Tensor, k: torch.Tensor,
                                 position_ids: torch.Tensor):
         """Project Q/K and apply RoPE"""
         q = self.wq_b(qr)
-        k = self.k_norm(indexer_k)
         q = q.view(-1, self.n_heads, self.head_dim)
         q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
@@ -1537,7 +1534,9 @@ class Indexer(nn.Module):
                                dim=-1)
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
         k_pe = k_pe[:, 0, :]
-        return q_pe, q_nope, k_pe, k_nope
+        q = self._prep_q_or_k(q_pe, q_nope)
+        k = self._prep_q_or_k(k_pe, k_nope)
+        return q, k

     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
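With quantization folded into `_qk_projection_and_rope`, the method now returns two (fp8, scale) pairs instead of four raw tensors. As a rough stand-in for what `_prep_q_or_k` produces, here is a hedged toy version based only on its docstring; the real implementation and its scale granularity are not visible in this hunk:

import torch

def prep_q_or_k_sketch(qk_pe: torch.Tensor, qk_nope: torch.Tensor):
    """Toy stand-in for Indexer._prep_q_or_k (assumed behavior): concatenate
    the RoPE and non-RoPE halves, then quantize to FP8 with a per-tensor
    scale. The docstring's "rotate" step is omitted here."""
    qk = torch.cat([qk_pe, qk_nope], dim=-1)
    scale = qk.abs().amax().clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
    qk_fp8 = (qk / scale).to(torch.float8_e4m3fn)
    return qk_fp8, scale  # matches the (fp8, scale) pairs unpacked downstream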
@@ -1562,14 +1561,7 @@ class Indexer(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        q_pe, q_nope, k_pe, k_nope = q_and_k
-        q, k = maybe_execute_in_parallel(
-            lambda: self._prep_q_or_k(q_pe, q_nope),
-            lambda: self._prep_q_or_k(k_pe, k_nope),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
+        q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
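Because `_qk_projection_and_rope` already quantizes both halves, the second `maybe_execute_in_parallel` round trip disappears. For reference, a minimal sketch of the two-stream pattern that helper implements; the semantics below are reconstructed from the call sites in this diff, not taken from the helper's actual source:

import torch
from typing import Callable, Optional, Tuple

def maybe_execute_in_parallel(
        fn_a: Callable, fn_b: Callable,
        event_a: torch.cuda.Event, event_b: torch.cuda.Event,
        aux_stream: Optional[torch.cuda.Stream] = None) -> Tuple:
    """Run fn_b on aux_stream while fn_a runs on the current stream
    (assumed semantics). Falls back to sequential execution when no aux
    stream is available."""
    if aux_stream is None:
        return fn_a(), fn_b()
    # Fork: the aux stream must wait for work already queued on the main stream.
    event_a.record()
    with torch.cuda.stream(aux_stream):
        event_a.wait()
        result_b = fn_b()
        event_b.record()
    result_a = fn_a()
    # Join: the main stream waits for the aux-stream branch to finish.
    event_b.wait()
    return result_a, result_b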
@@ -1585,8 +1577,8 @@ class Indexer(nn.Module):
         )

         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)


 class DSATrtllmAttention(TrtllmAttention):
@@ -1354,7 +1354,6 @@ class MLA(nn.Module):
                 self.indexer.head_dim
             ], -1)

-        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1363,9 +1362,14 @@ class MLA(nn.Module):
             self.aux_stream,
         )
         qr = q
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+        latent_cache, indexer_k = maybe_execute_in_parallel(
+            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
+            lambda: self.indexer.k_norm(indexer_k),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )

         # TODO: fuse wq_b + (indexer) wlq here
         q = self.q_b_proj(q)
         # Indexer
         topk_indices = self.indexer(
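On the MLA side, the indexer's key layernorm now overlaps with building `latent_cache`, which is what lets the `k_norm` call drop out of `_qk_projection_and_rope` above. A toy usage of that overlap (shapes, the RMSNorm module, and the CUDA device are placeholders; this assumes the sketch of `maybe_execute_in_parallel` given earlier):

import torch

# Placeholder dimensions; the real ones come from the model config.
compressed_kv = torch.randn(16, 512, device='cuda')
k_pe = torch.randn(16, 64, device='cuda')
indexer_k_in = torch.randn(16, 128, device='cuda')
k_norm = torch.nn.RMSNorm(128, device='cuda')
ln_events = [torch.cuda.Event(), torch.cuda.Event()]
aux_stream = torch.cuda.Stream()

# The concat and the norm have no data dependency, so they can run on
# separate streams and overlap.
latent_cache, indexer_k = maybe_execute_in_parallel(
    lambda: torch.concat([compressed_kv, k_pe], dim=-1),
    lambda: k_norm(indexer_k_in),
    ln_events[0],
    ln_events[1],
    aux_stream,
)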
@@ -1559,8 +1559,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
               f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")

     indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
-    topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
-                                                       hidden_states, q_fp8,
+    topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked, q_fp8,
                                                        k_fp8, k_scale, weights)

     print(f"✓ Chunked execution completed, shape: {topk_indices_chunked.shape}")
@@ -1592,8 +1591,8 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):

     indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
-                                                        hidden_states, q_fp8,
-                                                        k_fp8, k_scale, weights)
+                                                        q_fp8, k_fp8, k_scale,
+                                                        weights)

     print(
         f"✓ Non-chunked execution completed, shape: {topk_indices_baseline.shape}"
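All remaining test hunks make the same mechanical change: delete the `hidden_states` argument and keep the rest of the argument order. A self-contained illustration of the migrated call shape, using a hypothetical stub rather than the real `Indexer`:

import torch

class IndexerStub:
    """Hypothetical stand-in mirroring only the post-change signature."""
    index_topk = 2048

    def sparse_attn_indexer(self, metadata, q_fp8, k_fp8, k_scale, weights):
        num_tokens = q_fp8.shape[0]
        return torch.full((num_tokens, self.index_topk), -1,
                          dtype=torch.int32, device=q_fp8.device)

indexer = IndexerStub()
q_fp8 = torch.zeros(8, 64, 128)
# Old call order: indexer.sparse_attn_indexer(metadata, hidden_states, q_fp8, ...)
topk = indexer.sparse_attn_indexer(None, q_fp8, None, None, None)
assert topk.shape == (8, indexer.index_topk)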
@@ -1866,7 +1865,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,

     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -1892,7 +1890,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -1921,7 +1918,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     try:
         topk_indices_skip = indexer.sparse_attn_indexer(
             metadata_skip,
-            hidden_states,
             q_fp8,
             k_fp8,
             k_scale,
@@ -2035,7 +2031,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,

     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2055,7 +2050,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2143,7 +2137,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,

     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2166,7 +2159,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
     metadata_fallback.indexer_prefill_chunks = None

     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2191,7 +2183,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,

     try:
         topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2302,7 +2293,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):

     # Test custom kernel
     topk_custom = indexer.sparse_attn_indexer(metadata,
-                                              hidden_states,
                                               q_fp8,
                                               k_fp8,
                                               k_scale,
@@ -2311,7 +2301,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):

     # Test fallback
     topk_fallback = indexer.sparse_attn_indexer(metadata,
-                                                hidden_states,
                                                 q_fp8,
                                                 k_fp8,
                                                 k_scale,
@@ -2337,7 +2326,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
     Indexer.prepare(metadata_skip)
     indexer._update_k_cache(k_fp8, k_scale, metadata_skip)
     topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                    hidden_states,
                                                     q_fp8,
                                                     k_fp8,
                                                     k_scale,