[None][fix] Avoid unnecessary concat in attn_output_gate case. (#8094)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Author: Yuxian Qiu, 2025-10-14 03:59:40 +08:00 (committed by GitHub)
parent 6c4cc4c8b2
commit bd740c9ba6
2 changed files with 4 additions and 4 deletions

@@ -538,10 +538,9 @@ class Attention(nn.Module):
                 t.reshape(*orig_shape, -1) for t in torch.chunk(
                     q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
             ]
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
-        q, k, v = qkv, None, None
+        else:
+            q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
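
To make the first hunk's intent concrete, here is a minimal, self-contained sketch (not the TensorRT-LLM implementation) of the control flow after the change: the attn_output_gate branch keeps the split q, k, v tensors instead of re-concatenating them into qkv, while the non-gate branch passes the fused qkv through as q. The toy dimensions, the free-standing forward function, and the identity apply_rope stub are assumptions for illustration only.

import torch

num_heads, head_dim, kv_heads = 4, 8, 2
q_size = num_heads * head_dim   # 32
kv_size = kv_heads * head_dim   # 16

def apply_rope(q, k, v):
    # Stub: the real method applies rotary position embeddings; identity here.
    return q, k, v

def forward(qkv: torch.Tensor, attn_output_gate: bool):
    if attn_output_gate:
        # qkv packs [q_gate | k | v]; q_gate interleaves q and gate per head.
        q_gate, k, v = qkv.split([q_size * 2, kv_size, kv_size], dim=-1)
        orig_shape = q_gate.shape[:-1]
        q, gate = [
            t.reshape(*orig_shape, -1) for t in torch.chunk(
                q_gate.view(*orig_shape, num_heads, -1), 2, dim=-1)
        ]
        # gate is later applied to the attention output (omitted here).
        # New behavior: q, k, v stay separate; no torch.concat back into qkv.
    else:
        q, k, v = qkv, None, None  # fused qkv flows through unchanged
    return apply_rope(q, k, v)

tokens = 3
qkv = torch.randn(tokens, q_size * 2 + 2 * kv_size)  # last dim: 96
q, k, v = forward(qkv, attn_output_gate=True)
print(q.shape, k.shape, v.shape)  # [3, 32] [3, 16] [3, 16]

Dropping the torch.concat avoids materializing an extra full copy of the activations on the gate path, which is the unnecessary concat the commit title refers to.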

@@ -249,6 +249,7 @@ class QKNormRoPEAttention(Attention):
         else:
             return q, k, v
-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
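
And a corresponding sketch of what the second hunk does to QKNormRoPEAttention.apply_rope, again with an identity stub standing in for the fused QK-norm + RoPE kernel: because the base class now forwards separate q, k, v on the gate path, the override accepts either a fused qkv tensor (k and v are None) or three separate tensors, and concatenates only in the latter case, at the point where a single packed tensor is actually required.

import torch
from typing import Optional

def apply_qk_norm_rope(qkv: torch.Tensor, position_ids=None) -> torch.Tensor:
    # Stand-in for the fused QK-norm + RoPE kernel, which expects one tensor.
    return qkv

def apply_rope(q: torch.Tensor,
               k: Optional[torch.Tensor],
               v: Optional[torch.Tensor],
               position_ids=None) -> torch.Tensor:
    qkv = q
    if k is not None and v is not None:
        # Separate tensors (the attn_output_gate path): fuse them only here.
        qkv = torch.concat([q, k, v], dim=-1)
    return apply_qk_norm_rope(qkv, position_ids)

fused = torch.randn(3, 96)
print(apply_rope(fused, None, None).shape)  # torch.Size([3, 96]), fused input
q, k, v = torch.randn(3, 64), torch.randn(3, 16), torch.randn(3, 16)
print(apply_rope(q, k, v).shape)            # torch.Size([3, 96]), separate input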