Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
fix.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
commit 63888cf3a8
parent 152d46b566
@@ -1535,7 +1535,9 @@ class Indexer(nn.Module):
                                dim=-1)
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
         k_pe = k_pe[:, 0, :]
-        return q_pe, q_nope, k_pe, k_nope
+        q = self._prep_q_or_k(q_pe, q_nope)
+        k = self._prep_q_or_k(k_pe, k_nope)
+        return q, k
 
     def _weight_proj(self, hidden_states: torch.Tensor):
         return self.weights_proj(_to_float(hidden_states))
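Each _prep_q_or_k call added above returns a tuple that the second hunk unpacks as an FP8 tensor plus its quantization scales, with quant_block_size asserted to be 128. As a hedged illustration of that contract only (prep_q_or_k_sketch, FP8_MAX, and the per-128-block scaling below are assumptions, not the actual TensorRT-LLM kernel), the preparation step could look like:

import torch

QUANT_BLOCK_SIZE = 128   # mirrors the assert in the second hunk
FP8_MAX = 448.0          # largest finite magnitude of torch.float8_e4m3fn


def prep_q_or_k_sketch(pe: torch.Tensor, nope: torch.Tensor):
    # Join the non-rotary and rotary halves back into full heads, then
    # quantize to FP8 with one scale per 128-element block of the last dim.
    x = torch.cat([nope, pe], dim=-1)                      # [..., head_dim]
    blocks = x.view(*x.shape[:-1], -1, QUANT_BLOCK_SIZE)   # [..., n_blk, 128]
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = amax / FP8_MAX                                 # per-block scale
    x_fp8 = (blocks / scale).to(torch.float8_e4m3fn).view_as(x)
    return x_fp8, scale.squeeze(-1)

The real implementation likely fuses this into a dedicated kernel; the sketch only pins down the (fp8_tensor, scale) shape of the result that the caller below relies on.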
@@ -1556,14 +1558,7 @@ class Indexer(nn.Module):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
 
-        q_pe, q_nope, k_pe, k_nope = q_and_k
-        q, k = maybe_execute_in_parallel(
-            lambda: self._prep_q_or_k(q_pe, q_nope),
-            lambda: self._prep_q_or_k(k_pe, k_nope),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
+        q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
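The removed block wrapped the two _prep_q_or_k launches in maybe_execute_in_parallel, which, judging by its arguments (two callables, two CUDA events, and self.aux_stream), overlaps the second call on an auxiliary CUDA stream; the commit replaces it with the direct calls added in the first hunk. A minimal sketch of that overlap pattern, using a hypothetical helper run_maybe_parallel rather than the actual tensorrt_llm utility:

from typing import Callable, Optional, Tuple, TypeVar
import torch

T = TypeVar("T")
U = TypeVar("U")


def run_maybe_parallel(
    fn_a: Callable[[], T],
    fn_b: Callable[[], U],
    event_a: torch.cuda.Event,
    event_b: torch.cuda.Event,
    aux_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[T, U]:
    if aux_stream is None:
        # No auxiliary stream available: fall back to sequential execution.
        return fn_a(), fn_b()
    event_a.record()                      # mark the work fn_b's inputs depend on
    result_a = fn_a()                     # launched on the current stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)    # inputs enqueued so far are visible
        result_b = fn_b()                 # overlaps with fn_a's kernels
        event_b.record()
    torch.cuda.current_stream().wait_event(event_b)  # rejoin before results are used
    return result_a, result_b

With aux_stream set to None this pattern degenerates to running the two callables back to back, which is effectively what the code does after this change, since the two _prep_q_or_k calls are now issued sequentially in the first hunk.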