Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Fanrong Li 2026-01-05 08:26:40 -08:00
parent 63888cf3a8
commit 2a5d5e9616
2 changed files with 12 additions and 22 deletions

View File

@@ -1539,9 +1539,6 @@ class Indexer(nn.Module):
         k = self._prep_q_or_k(k_pe, k_nope)
         return q, k
 
-    def _weight_proj(self, hidden_states: torch.Tensor):
-        return self.weights_proj(_to_float(hidden_states))
-
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
         q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
@@ -1552,12 +1549,19 @@ class Indexer(nn.Module):
         return q_or_k
 
     @torch.inference_mode()
-    def forward(self, q_and_k: torch.Tensor, weights: torch.Tensor,
+    def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata,
                 position_ids: torch.Tensor, indexer_k: torch.Tensor):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
+        q_and_k, weights = maybe_execute_in_parallel(
+            lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
+            lambda: self.weights_proj(_to_float(hidden_states)),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
         q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
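
The reworked forward unpacks (fp8, scale) pairs for both q and k, with quant_block_size pinned to 128. For orientation, here is a minimal sketch of per-block FP8 quantization along the last dimension; the name fp8_block_quant and the e4m3 format choice are assumptions for illustration, not the kernel this code actually calls:

    import torch

    def fp8_block_quant(x: torch.Tensor, block_size: int = 128):
        # Sketch only: split the last dim into blocks of `block_size`,
        # derive one scale per block from that block's absmax, cast to FP8.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        blocks = x.view(*x.shape[:-1], -1, block_size)
        scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / fp8_max
        x_fp8 = (blocks / scale).to(torch.float8_e4m3fn).view_as(x)
        return x_fp8, scale.squeeze(-1)  # analogous to the (q_fp8, q_scale) pairs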

View File

@@ -1346,14 +1346,6 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: The output tensor.
         """
-
-        def _qk_projection_and_rope(q: torch.Tensor, k: torch.Tensor,
-                                    position_ids: torch.Tensor):
-            q_b_proj = self.q_b_proj(q)
-            indexer_q_and_k = self.indexer._qk_projection_and_rope(
-                q, k, position_ids)
-            return q_b_proj, indexer_q_and_k
-
         assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode"
         # split q, k, v into context and gen batches
         num_contexts = attn_metadata.num_contexts
@@ -1378,6 +1370,7 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
+        qr = q
         latent_cache, indexer_k = maybe_execute_in_parallel(
             lambda: torch.concat([compressed_kv, k_pe], dim=-1),
             lambda: self.indexer.k_norm(indexer_k),
@@ -1385,19 +1378,12 @@ class MLA(nn.Module):
             self.ln_events[1],
             self.aux_stream,
         )
-        q_and_k, indexer_weights = maybe_execute_in_parallel(
-            lambda: _qk_projection_and_rope(q, indexer_k, position_ids),
-            lambda: self.indexer._weight_proj(hidden_states),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
-        q, indxer_q_and_k = q_and_k
+        q = self.q_b_proj(q)
         # Indexer
         topk_indices = self.indexer(
-            indxer_q_and_k,
-            indexer_weights,
+            qr,
+            hidden_states,
             attn_metadata,
             position_ids,
             indexer_k=indexer_k,  # indexer K proj
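
Both files lean on the same two-stream overlap idiom. Below is a minimal sketch of the pattern behind maybe_execute_in_parallel, assuming event/stream semantics like the helper used in this codebase; the name overlap_on_aux_stream and the exact signature are illustrative assumptions, not the project's API:

    from typing import Callable, Optional, Tuple, TypeVar

    import torch

    T0 = TypeVar("T0")
    T1 = TypeVar("T1")

    def overlap_on_aux_stream(
            fn0: Callable[[], T0],
            fn1: Callable[[], T1],
            event0: torch.cuda.Event,
            event1: torch.cuda.Event,
            aux_stream: Optional[torch.cuda.Stream]) -> Tuple[T0, T1]:
        if aux_stream is None:
            return fn0(), fn1()  # no aux stream: run the two sequentially
        event0.record()          # mark the main stream's current position
        result0 = fn0()          # fn0 runs on the main stream
        with torch.cuda.stream(aux_stream):
            event0.wait()        # aux stream waits for prior main-stream work
            result1 = fn1()      # fn1 overlaps with fn0
            event1.record()
        event1.wait()            # main stream waits before consuming result1
        return result0, result1

The commit's net effect is to move this overlap for the indexer's QK projection/RoPE and weight projection out of MLA.forward and into Indexer.forward: MLA now saves the raw query as qr, applies q_b_proj itself, and hands qr and hidden_states to the indexer, which schedules its own parallel work.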