[TRTLLM-10309] [feat] Optimize qk rope/nope concat for DSA (#10571)

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Author: Kaiyu Xie
Date: 2026-01-09 22:50:57 +08:00 (committed via GitHub)
parent ced88424ef
commit 1c69aad850
3 changed files with 14 additions and 13 deletions
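
The diff replaces the eager torch.cat in the DSA indexer's q/k rope/nope concatenation with maybe_compiled_cat, and moves the maybe_compiled_copy_ / maybe_compiled_cat helpers from the attention module into the shared tensorrt_llm._torch.utils module so both call sites can import them. As a rough sketch of the pattern (the real maybe_compile decorator may gate compilation on runtime or config state; the environment switch below is purely a hypothetical stand-in):

    # Sketch only: approximates a maybe_compile-style decorator; the actual
    # implementation in tensorrt_llm._torch.utils may differ.
    import os
    import torch

    def maybe_compile_sketch(fn):
        # Return the eager function when compilation is disabled; otherwise
        # wrap it with torch.compile so small ops can be fused into a
        # compiled region instead of launching standalone eager kernels.
        if os.environ.get("DISABLE_COMPILE", "0") == "1":
            return fn
        return torch.compile(fn)

    @maybe_compile_sketch
    def compiled_cat(tensors, dim):
        return torch.cat(tensors, dim)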


@@ -17,7 +17,7 @@ from tensorrt_llm._torch.modules.multi_stream_utils import \
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
-from tensorrt_llm._torch.utils import maybe_compile
+from tensorrt_llm._torch.utils import maybe_compile, maybe_compiled_cat
 from tensorrt_llm._utils import get_size_in_bytes, get_sm_version
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -1541,7 +1541,7 @@ class Indexer(nn.Module):
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
-        q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
+        q_or_k = maybe_compiled_cat([qk_pe, qk_nope], dim=-1)
         q_or_k = rotate_activation(q_or_k)
         q_or_k = q_or_k.view(-1, self.head_dim)
         q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
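
In the Indexer hunk above, maybe_compiled_cat is a drop-in replacement for torch.cat, so the concatenated q/k tensor is unchanged; only the execution path (eager kernel vs. a region that torch.compile can fuse) differs. A minimal equivalence check, using made-up shapes that are for illustration only and assuming the updated tensorrt_llm._torch.utils is importable:

    import torch
    from tensorrt_llm._torch.utils import maybe_compiled_cat

    # Hypothetical rope/nope widths; the real DSA indexer head dims may differ.
    qk_pe = torch.randn(4, 16, 64)
    qk_nope = torch.randn(4, 16, 128)

    eager = torch.cat([qk_pe, qk_nope], dim=-1)
    # First call may trigger compilation (requires a working torch.compile backend).
    fused = maybe_compiled_cat([qk_pe, qk_nope], dim=-1)
    assert torch.equal(eager, fused)  # same result, potentially fewer kernel launches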


@@ -25,7 +25,8 @@ from ..distributed import AllReduceParams, HelixAllToAllNative, alltoall_helix
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling, maybe_compile)
+                     is_torch_compiling, maybe_compiled_cat,
+                     maybe_compiled_copy_)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -78,16 +79,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer


-@maybe_compile
-def maybe_compiled_copy_(dst, src):
-    dst.copy_(src)
-
-
-@maybe_compile
-def maybe_compiled_cat(tensors, dim):
-    return torch.cat(tensors, dim)
-
-
 def create_attn_outputs_impl(q: torch.Tensor, attention_mask: str,
                              layer_idx: str) -> List[torch.Tensor]:
     metadata, attn_layer = extract_extra_attrs(layer_idx, "attn")


@@ -404,3 +404,13 @@ def split(x: torch.Tensor,
 def relu2(x: torch.Tensor) -> torch.Tensor:
     return torch.square(F.relu(x))
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
+    dst.copy_(src)
+
+
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
+    return torch.cat(tensors, dim)
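
With the helpers now defined in the shared utils module, callers such as the DSA indexer can import them directly rather than duplicating them per module. A small usage sketch of the in-place copy helper, with illustrative buffers (in attention code the destination would typically be a preallocated cache or workspace slice, not a fresh tensor):

    import torch
    from tensorrt_llm._torch.utils import maybe_compiled_copy_

    dst = torch.empty(4, 128)
    src = torch.randn(4, 128)

    maybe_compiled_copy_(dst, src)   # in-place copy, compilable when enabled
    assert torch.equal(dst, src)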