Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-31 08:11:27 +08:00)
Commit 3519a1a685 (parent 6507087c3f)

optimize the multi-stream for DSA.

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
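The diff below moves the DSA indexer's Q/K projection + RoPE and its weight projection out of Indexer.forward and overlaps them with MLA's own work (q_b_proj, the latent-cache concatenation) through the existing maybe_execute_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream) helper. That helper's implementation is not part of this diff, so the following is only a minimal sketch of the two-stream pattern such a helper typically follows; the name overlap_on_aux_stream and its body are illustrative assumptions, not TensorRT-LLM code.

import torch

def overlap_on_aux_stream(fn_a, fn_b, event_a: torch.cuda.Event,
                          event_b: torch.cuda.Event,
                          aux_stream: torch.cuda.Stream):
    # Record where the main stream is, so the aux stream only waits for work
    # that was enqueued before this call.
    event_a.record()
    result_a = fn_a()  # enqueued on the current (main) stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)
        result_b = fn_b()  # enqueued on the aux stream, overlapping fn_a
        event_b.record()
    # The main stream must not consume result_b before the aux stream is done.
    torch.cuda.current_stream().wait_event(event_b)
    return result_a, result_b

With that shape in mind, the hunks below mostly regroup which callables are handed to the helper so that more of the indexer's projections run concurrently with the MLA projections.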
@@ -1526,11 +1526,10 @@ class Indexer(nn.Module):
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights
 
-    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+    def _qk_projection_and_rope(self, qr: torch.Tensor, k: torch.Tensor,
                                 position_ids: torch.Tensor):
         """Project Q/K and apply RoPE"""
         q = self.wq_b(qr)
-        k = self.k_norm(indexer_k)
         q = q.view(-1, self.n_heads, self.head_dim)
         q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
@@ -1540,6 +1539,9 @@ class Indexer(nn.Module):
         k_pe = k_pe[:, 0, :]
         return q_pe, q_nope, k_pe, k_nope
 
+    def _weight_proj(self, hidden_states: torch.Tensor):
+        return self.weights_proj(_to_float(hidden_states))
+
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
         q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
@@ -1550,19 +1552,12 @@ class Indexer(nn.Module):
         return q_or_k
 
     @torch.inference_mode()
-    def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
+    def forward(self, q_and_k: torch.Tensor, weights: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata,
                 position_ids: torch.Tensor, indexer_k: torch.Tensor):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
 
-        q_and_k, weights = maybe_execute_in_parallel(
-            lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
-            lambda: self.weights_proj(_to_float(hidden_states)),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
         q_pe, q_nope, k_pe, k_nope = q_and_k
         q, k = maybe_execute_in_parallel(
             lambda: self._prep_q_or_k(q_pe, q_nope),
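After this hunk, Indexer.forward no longer launches its own projections: q_and_k must arrive as the (q_pe, q_nope, k_pe, k_nope) tuple produced by _qk_projection_and_rope, and weights as the output of the new _weight_proj helper. The caller owns that work now, which is what lets MLA.forward (hunks below) schedule it on overlapping streams. A simplified serial driver for the new contract might look like the following; only _qk_projection_and_rope, _weight_proj, k_norm and forward come from the diff, while the wrapper function and its serial scheduling are illustrative assumptions.

def run_indexer_serially(indexer, qr, hidden_states, indexer_k,
                         metadata, position_ids):
    # The layer norm on the indexer K path is now the caller's responsibility.
    k = indexer.k_norm(indexer_k)
    # (q_pe, q_nope, k_pe, k_nope), matching what forward() unpacks.
    q_and_k = indexer._qk_projection_and_rope(qr, k, position_ids)
    weights = indexer._weight_proj(hidden_states)
    # forward() then only does the FP8 prep of Q/K and the top-k selection.
    return indexer.forward(q_and_k, weights, metadata, position_ids,
                           indexer_k=k)

In the commit itself these calls are not issued back-to-back like this; they are folded into the overlapped maybe_execute_in_parallel stages added to MLA.forward below.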
@@ -1346,6 +1346,14 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: The output tensor.
         """
+
+        def _qk_projection_and_rope(q: torch.Tensor, k: torch.Tensor,
+                                    position_ids: torch.Tensor):
+            q_b_proj = self.q_b_proj(q)
+            indexer_q_and_k = self.indexer._qk_projection_and_rope(
+                q, k, position_ids)
+            return q_b_proj, indexer_q_and_k
+
         assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode"
         # split q, k, v into context and gen batches
         num_contexts = attn_metadata.num_contexts
@@ -1363,7 +1371,6 @@ class MLA(nn.Module):
             self.indexer.head_dim
         ], -1)
 
-        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1371,15 +1378,28 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        qr = q
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+        latent_cache, indexer_k = maybe_execute_in_parallel(
+            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
+            lambda: self.indexer.k_norm(indexer_k),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+        q_and_k, indexer_weights = maybe_execute_in_parallel(
+            lambda: _qk_projection_and_rope(q, indexer_k, position_ids),
+            lambda: self.indexer._weight_proj(hidden_states),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+        q, indexer_q_pe, indexer_q_nope, indexer_k_pe, indexer_k_nope = q_and_k
+        indxer_q_and_k = (indexer_q_pe, indexer_q_nope, indexer_k_pe,
+                          indexer_k_nope)
 
-        # TODO: fuse wq_b + (indexer) wlq here
-        q = self.q_b_proj(q)
         # Indexer
         topk_indices = self.indexer(
-            qr,
-            hidden_states,
+            indxer_q_and_k,
+            indexer_weights,
             attn_metadata,
             position_ids,
+            indexer_k=indexer_k,  # indexer K proj
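How much the regrouped stages actually overlap depends on the device having spare SMs while the main-stream GEMMs run, so it is worth measuring on the target GPU. The snippet below is a self-contained sanity check written for this note, not part of the commit; it reuses the overlap_on_aux_stream sketch from above and times a pair of dummy matmuls enqueued serially versus split across the two streams.

import torch

def time_ms(fn, iters=50):
    # CUDA-event based timing of an enqueue-only callable.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    aux_stream = torch.cuda.Stream()
    ev0, ev1 = torch.cuda.Event(), torch.cuda.Event()

    serial = time_ms(lambda: (a @ a, b @ b))
    overlapped = time_ms(lambda: overlap_on_aux_stream(
        lambda: a @ a, lambda: b @ b, ev0, ev1, aux_stream))
    print(f"serial: {serial:.3f} ms/iter  overlapped: {overlapped:.3f} ms/iter")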