Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-10309] [feat] Optimize qk rope/nope concat for DSA (#10571)

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

commit 1c69aad850 (parent ced88424ef)
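In short: Indexer._prep_q_or_k previously concatenated the rope and nope halves with a plain torch.cat; it now calls maybe_compiled_cat so the concat can run under torch.compile. To make the helper importable from both call sites, maybe_compiled_copy_ and maybe_compiled_cat move out of the attention module into tensorrt_llm._torch.utils.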
@@ -17,7 +17,7 @@ from tensorrt_llm._torch.modules.multi_stream_utils import \
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
-from tensorrt_llm._torch.utils import maybe_compile
+from tensorrt_llm._torch.utils import maybe_compile, maybe_compiled_cat
 from tensorrt_llm._utils import get_size_in_bytes, get_sm_version
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -1541,7 +1541,7 @@ class Indexer(nn.Module):
 
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
-        q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
+        q_or_k = maybe_compiled_cat([qk_pe, qk_nope], dim=-1)
         q_or_k = rotate_activation(q_or_k)
         q_or_k = q_or_k.view(-1, self.head_dim)
         q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
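The point of the change above is that a bare torch.cat in eager mode always runs as its own copy kernel, while a maybe_compile-decorated wrapper lets torch.compile fuse the rope/nope concatenation with neighboring ops when compilation is enabled. The diff does not show how maybe_compile decides whether to compile, so the following is only a minimal sketch of the pattern; the EXAMPLE_ENABLE_COMPILE gate is a hypothetical stand-in for TensorRT-LLM's real configuration, not its actual API.

import os
from typing import Callable, List

import torch


def maybe_compile(fn: Callable) -> Callable:
    # Hypothetical gate: the real decorator in tensorrt_llm._torch.utils
    # consults the runtime's own configuration, not this env var.
    if os.environ.get("EXAMPLE_ENABLE_COMPILE") == "1":
        return torch.compile(fn)
    return fn


@maybe_compile
def maybe_compiled_cat(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
    # Inside a compiled region the concat can be fused with adjacent
    # elementwise work instead of launching a standalone eager kernel.
    return torch.cat(tensors, dim)


# Mirrors the call site in Indexer._prep_q_or_k: concatenate the rope and
# nope halves along the last dimension (shapes are illustrative only).
qk_pe = torch.randn(8, 64)
qk_nope = torch.randn(8, 128)
q_or_k = maybe_compiled_cat([qk_pe, qk_nope], dim=-1)
assert q_or_k.shape == (8, 192)

Keeping the wrapper a free function (rather than calling torch.compile at the call site) means eager execution is preserved whenever the gate is off, which keeps debugging simple.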
@@ -25,7 +25,8 @@ from ..distributed import AllReduceParams, HelixAllToAllNative, alltoall_helix
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling, maybe_compile)
+                     is_torch_compiling, maybe_compiled_cat,
+                     maybe_compiled_copy_)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -78,16 +79,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-@maybe_compile
-def maybe_compiled_copy_(dst, src):
-    dst.copy_(src)
-
-
-@maybe_compile
-def maybe_compiled_cat(tensors, dim):
-    return torch.cat(tensors, dim)
-
-
 def create_attn_outputs_impl(q: torch.Tensor, attention_mask: str,
                              layer_idx: str) -> List[torch.Tensor]:
     metadata, attn_layer = extract_extra_attrs(layer_idx, "attn")
@@ -404,3 +404,13 @@ def split(x: torch.Tensor,
 
 def relu2(x: torch.Tensor) -> torch.Tensor:
     return torch.square(F.relu(x))
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
+    dst.copy_(src)
+
+
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
+    return torch.cat(tensors, dim)
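After this move the two helpers are defined once in tensorrt_llm._torch.utils and imported by both the indexer and the attention module, instead of living privately next to one call site. A short usage sketch, assuming a tensorrt_llm build that includes this commit; the tensor shapes are illustrative only:

import torch

from tensorrt_llm._torch.utils import maybe_compiled_cat, maybe_compiled_copy_

# Concatenate the rope/nope halves, as Indexer._prep_q_or_k now does.
pe = torch.randn(4, 64)
nope = torch.randn(4, 128)
merged = maybe_compiled_cat([pe, nope], dim=-1)

# In-place copy into a preallocated buffer; under torch.compile this can
# be fused with the ops that produce `merged`.
out = torch.empty_like(merged)
maybe_compiled_copy_(out, merged)
assert torch.equal(out, merged)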