From bd740c9ba6c96adbc1d5255e737b4635f66bf512 Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Tue, 14 Oct 2025 03:59:40 +0800
Subject: [PATCH] [None][fix] Avoid unnecessary concat in attn_output_gate case. (#8094)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/attention.py         | 5 ++---
 tensorrt_llm/_torch/modules/qk_norm_attention.py | 3 ++-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 8e2f042323..eb0071e2b5 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -538,10 +538,9 @@ class Attention(nn.Module):
                 t.reshape(*orig_shape, -1) for t in torch.chunk(
                     q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
             ]
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+        else:
+            q, k, v = qkv, None, None
 
-        q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
 
diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py
index b116394989..e69fb33d1d 100644
--- a/tensorrt_llm/_torch/modules/qk_norm_attention.py
+++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py
@@ -249,6 +249,7 @@ class QKNormRoPEAttention(Attention):
         else:
             return q, k, v
 
-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
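
For reference, the sketch below illustrates the control flow the two hunks produce together: the attn_output_gate branch in Attention keeps q/k/v separate instead of re-concatenating them, and QKNormRoPEAttention fuses them back into a single qkv tensor only when separate tensors actually arrive. This is a minimal, illustrative sketch, not the real module code: per-head reshaping, the gate itself, and the apply_qk_norm_rope call are elided, and all shapes, function names (forward_sketch, apply_rope_qk_norm), and sizes are placeholders.

    # Simplified sketch of the post-patch qkv handling; not the actual
    # TensorRT-LLM implementation (gate handling and QK norm are elided).
    from typing import Optional

    import torch


    def apply_rope_qk_norm(q: torch.Tensor, k: Optional[torch.Tensor],
                           v: Optional[torch.Tensor]) -> torch.Tensor:
        # Mirrors the qk_norm_attention.py change: accept either a fused qkv
        # tensor (k and v are None) or separate q/k/v, fusing only when needed.
        qkv = q
        if k is not None and v is not None:
            qkv = torch.concat([q, k, v], dim=-1)
        return qkv  # the real code forwards this to apply_qk_norm_rope(...)


    def forward_sketch(qkv: torch.Tensor, q_size: int, kv_size: int,
                       attn_output_gate: bool) -> torch.Tensor:
        # Mirrors the attention.py change: the gated path keeps q/k/v separate
        # instead of re-concatenating; the ungated path passes fused qkv.
        if attn_output_gate:
            q_gate, k, v = qkv.split([q_size * 2, kv_size, kv_size], dim=-1)
            q, _gate = torch.chunk(q_gate, 2, dim=-1)  # per-head gate elided
        else:
            q, k, v = qkv, None, None
        return apply_rope_qk_norm(q, k, v)


    if __name__ == "__main__":
        q_size, kv_size = 8, 4
        qkv = torch.randn(2, q_size * 2 + 2 * kv_size)
        fused = forward_sketch(qkv, q_size, kv_size, attn_output_gate=True)
        print(fused.shape)  # torch.Size([2, 16]) == q_size + 2 * kv_size

In the gated path this avoids building a temporary fused qkv tensor in Attention only to split it again later, while the downstream QK-norm path still receives a fused tensor when it needs one.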