[TRTLLM-10309] [feat] Optimize qk rope/nope concat for DSA (#10571)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Commit: 1c69aad850
Parent: ced88424ef
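In short, this change routes the qk_pe/qk_nope concatenation in the DSA indexer through a maybe_compile-wrapped torch.cat helper instead of an eager torch.cat, and relocates the maybe_compiled_cat / maybe_compiled_copy_ helpers into the shared tensorrt_llm._torch.utils module so both call sites import the same functions. A minimal sketch of the pattern, assuming maybe_compile behaves roughly like a conditional torch.compile decorator (the _COMPILE_ENABLED flag below is hypothetical and only for illustration):

# Sketch only: a simplified stand-in for the maybe_compile decorator from
# tensorrt_llm._torch.utils. Assumption: the real decorator applies
# torch.compile conditionally; the _COMPILE_ENABLED flag here is hypothetical.
import torch

_COMPILE_ENABLED = True


def maybe_compile(fn):
    # Compile the wrapped op when enabled, otherwise fall back to eager.
    return torch.compile(fn) if _COMPILE_ENABLED else fn


@maybe_compile
def maybe_compiled_cat(tensors, dim):
    # Same semantics as torch.cat; compiling lets the concat be fused with
    # neighboring pointwise ops instead of launching a separate eager kernel.
    return torch.cat(tensors, dim)


@maybe_compile
def maybe_compiled_copy_(dst, src):
    # In-place copy that torch.compile can fold into surrounding kernels.
    dst.copy_(src)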
@@ -17,7 +17,7 @@ from tensorrt_llm._torch.modules.multi_stream_utils import \
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
-from tensorrt_llm._torch.utils import maybe_compile
+from tensorrt_llm._torch.utils import maybe_compile, maybe_compiled_cat
 from tensorrt_llm._utils import get_size_in_bytes, get_sm_version
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -1541,7 +1541,7 @@ class Indexer(nn.Module):
 
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
-        q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
+        q_or_k = maybe_compiled_cat([qk_pe, qk_nope], dim=-1)
         q_or_k = rotate_activation(q_or_k)
         q_or_k = q_or_k.view(-1, self.head_dim)
         q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
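The only functional change in _prep_q_or_k is that the concat of the rope and nope halves now goes through the compiled helper, so torch.compile can fuse it with the surrounding pointwise work rather than materializing the result eagerly. A self-contained sketch of the before/after pattern (the shapes, head_dim, and the stand-in for rotate_activation below are made up for illustration):

# Illustrative only: compares eager torch.cat with a torch.compile-wrapped cat
# on shapes loosely resembling the indexer's rope/nope split. The sizes and the
# simple post-processing are assumptions, not the actual DSA kernels.
import torch


@torch.compile
def compiled_cat(tensors, dim):
    return torch.cat(tensors, dim)


def prep_q_or_k(qk_pe, qk_nope, head_dim, use_compiled=True):
    cat = compiled_cat if use_compiled else torch.cat
    q_or_k = cat([qk_pe, qk_nope], dim=-1)   # concat rope + nope halves
    q_or_k = q_or_k * 0.5                    # stand-in for rotate_activation
    return q_or_k.view(-1, head_dim)         # flatten heads, as in the diff


if __name__ == "__main__":
    pe = torch.randn(8, 16, 64)      # hypothetical rope half
    nope = torch.randn(8, 16, 128)   # hypothetical nope half
    out = prep_q_or_k(pe, nope, head_dim=192)
    print(out.shape)                 # torch.Size([128, 192])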
@@ -25,7 +25,8 @@ from ..distributed import AllReduceParams, HelixAllToAllNative, alltoall_helix
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling, maybe_compile)
+                     is_torch_compiling, maybe_compiled_cat,
+                     maybe_compiled_copy_)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -78,16 +79,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-@maybe_compile
-def maybe_compiled_copy_(dst, src):
-    dst.copy_(src)
-
-
-@maybe_compile
-def maybe_compiled_cat(tensors, dim):
-    return torch.cat(tensors, dim)
-
-
 def create_attn_outputs_impl(q: torch.Tensor, attention_mask: str,
                              layer_idx: str) -> List[torch.Tensor]:
     metadata, attn_layer = extract_extra_attrs(layer_idx, "attn")
@@ -404,3 +404,13 @@ def split(x: torch.Tensor,
 
 def relu2(x: torch.Tensor) -> torch.Tensor:
     return torch.square(F.relu(x))
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
+    dst.copy_(src)
+
+
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
+    return torch.cat(tensors, dim)
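With the helpers now defined next to relu2 in the shared utils module, the attention code and the DSA indexer import the same compiled wrappers. A small usage sketch; the import path comes from the diff above, while the eager fallback and the tensor shapes are only so the snippet runs standalone:

# Usage sketch. Assumption: maybe_compiled_cat / maybe_compiled_copy_ are
# importable from tensorrt_llm._torch.utils as the diff shows; the fallback
# below is only for environments without a TensorRT-LLM install.
import torch

try:
    from tensorrt_llm._torch.utils import (maybe_compiled_cat,
                                           maybe_compiled_copy_)
except ImportError:
    def maybe_compiled_cat(tensors, dim):
        return torch.cat(tensors, dim)

    def maybe_compiled_copy_(dst, src):
        dst.copy_(src)

q_pe = torch.randn(4, 8, 64)     # hypothetical rope half
q_nope = torch.randn(4, 8, 128)  # hypothetical nope half

q = maybe_compiled_cat([q_pe, q_nope], dim=-1)   # -> shape (4, 8, 192)

buf = torch.empty_like(q)
maybe_compiled_copy_(buf, q)                     # in-place copy into the buffer
print(q.shape, torch.equal(buf, q))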