mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][fix] Avoid unnecessary concat in attn_output_gate case. (#8094)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent
6c4cc4c8b2
commit
bd740c9ba6
@@ -538,10 +538,9 @@ class Attention(nn.Module):
            t.reshape(*orig_shape, -1) for t in torch.chunk(
                q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
        ]
        ### TODO: avoid the redundant split and concat
        qkv = torch.concat([q, k, v], dim=-1)
    else:
        q, k, v = qkv, None, None

        q, k, v = qkv, None, None
    q, k, v = self.apply_rope(q, k, v, position_ids)
    q, k, v = self.convert_qkv(q, k, v)

(NOTE: extraction lost the diff's +/- gutter; line order is preserved from the page, but which lines are additions vs. removals must be confirmed against commit bd740c9ba6.)
||||
@@ -249,6 +249,7 @@ class QKNormRoPEAttention(Attention):
    else:
        return q, k, v

    assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
    qkv = q
    if k is not None and v is not None:
        qkv = torch.concat([q, k, v], dim=-1)
    return self.apply_qk_norm_rope(qkv, position_ids)

(NOTE: extraction lost the diff's +/- gutter; the `assert` and the `if k is not None` branch are the old and new versions of the same logic — confirm which side was removed against commit bd740c9ba6.)
Loading…
Reference in New Issue
Block a user