Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
chore: Refactor apply_rope. (#4918)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
parent 6b17dff2f1
commit c104388d37
@@ -14,7 +14,7 @@ from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention, QkNormType
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.linear import Linear, TensorParallelMode
@@ -53,7 +53,6 @@ class Gemma3Attention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=pos_embd_params,
-            qk_norm_type=QkNormType.pre_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
@@ -113,6 +112,13 @@ class Gemma3Attention(Attention):
 
         return q, k
 
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        # Gemma3 applies QK norm before RoPE.
+        q, k, v = self.split_qkv(q, k, v)
+        q, k = self.apply_qk_norm(q, k)
+        return super().apply_rope(q, k, v, position_ids)
+
 
 class Gemma3MLP(nn.Module):
 
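To see the override pattern in isolation: the hunk above makes Gemma3 normalize q/k and then delegate splitting and RoPE to the base apply_rope. A self-contained toy sketch of that flow, assuming nothing from TensorRT-LLM (ToyAttention, ToyPreRopeNormAttention and the rms helper are illustrative stand-ins, not the real classes):

from typing import Optional

import torch


class ToyAttention:
    """Stand-in for the base Attention: owns qkv splitting and (dummy) RoPE."""

    def __init__(self, q_size: int, kv_size: int):
        self.q_size, self.kv_size = q_size, kv_size

    def split_qkv(self, q, k, v):
        # When k and v are None, q is the fused [q | k | v] projection output.
        if k is None and v is None:
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q, k, v

    def apply_rope(self, q, k, v, position_ids: Optional[torch.Tensor]):
        q, k, v = self.split_qkv(q, k, v)  # no-op if already split
        # The real class would rotate q/k here via self.rotary_emb(position_ids, [q, k]).
        return q, k, v


class ToyPreRopeNormAttention(ToyAttention):
    """Gemma3-style subclass: RMS-normalize q and k, then run the base RoPE path."""

    def apply_qk_norm(self, q, k):
        def rms(x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

        return rms(q), rms(k)

    def apply_rope(self, q, k, v, position_ids):
        q, k, v = self.split_qkv(q, k, v)
        q, k = self.apply_qk_norm(q, k)  # norm before RoPE
        return super().apply_rope(q, k, v, position_ids)


qkv = torch.randn(2, 16)  # fused projection output
attn = ToyPreRopeNormAttention(q_size=8, kv_size=4)
q, k, v = attn.apply_rope(qkv, None, None, position_ids=torch.arange(2))
print(q.shape, k.shape, v.shape)  # torch.Size([2, 8]) torch.Size([2, 4]) torch.Size([2, 4])

Because split_qkv leaves already-split tensors untouched, the subclass can split early for the norm and still hand the result to the base implementation unchanged.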
@@ -23,7 +23,7 @@ from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..model_config import ModelConfig
-from ..modules.attention import Attention, QkNormType
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (Llama4RenormalizeMoeRoutingMethod,
@@ -60,6 +60,7 @@ class Llama4Attention(Attention):
             rope=RopeParams.from_config(config),
             is_neox=False,
         ) if self.use_rope else None
+        self.use_qk_norm = use_qk_norm
 
         if model_config.attn_backend != "TRTLLM":
             # TODO: support chunked attention for other backends.
@@ -74,15 +75,15 @@ class Llama4Attention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
-            qk_norm_type=QkNormType.post_rope
-            if use_qk_norm else QkNormType.none,
+            rope_fusion=not self.
+            use_qk_norm,  # Llama4 uses qk_norm after RoPE, so it is not possible to fuse RoPE into the attention OP with qk_norm.
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
             attention_chunk_size=attention_chunk_size,
         )
 
-        if self.use_rope and use_qk_norm:
+        if self.use_qk_norm:
             self.head_dim = config.hidden_size // config.num_attention_heads
             self.qk_norm = RMSNorm(hidden_size=self.head_dim,
                                    eps=1e-6,
@@ -115,6 +116,17 @@ class Llama4Attention(Attention):
 
         return q, k
 
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        q, k, v = self.split_qkv(q, k, v)
+        if position_ids is not None:
+            q, k, v = super().apply_rope(q, k, v, position_ids)
+        # Llama4 applies QK norm after RoPE.
+        if self.use_qk_norm:
+            q, k = self.apply_qk_norm(q, k)
+
+        return q, k, v
+
     def _attention_scaling(self, q, position_ids):
 
         def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
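The comment in the constructor hunk gives the operational reason for rope_fusion=not use_qk_norm: once RoPE is fused into the attention op, the rotated q/k never surface, so there is nowhere to apply a post-RoPE norm. As a side note, the two orderings are genuinely different computations whenever the norm carries an elementwise weight. A quick standalone check with a toy 2-D rotation and a weighted RMS norm (illustrative only; not a claim about the exact Llama4 norm):

import math

import torch


def rope_pairs(x: torch.Tensor, theta: float) -> torch.Tensor:
    """Rotate each (x[2i], x[2i+1]) pair by theta -- a minimal stand-in for RoPE."""
    x0, x1 = x[..., 0::2], x[..., 1::2]
    c, s = math.cos(theta), math.sin(theta)
    return torch.stack((x0 * c - x1 * s, x0 * s + x1 * c), dim=-1).flatten(-2)


def weighted_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


q = torch.randn(4, 8)
w = torch.rand(8) + 0.5  # elementwise norm weight

post = weighted_rms_norm(rope_pairs(q, 0.3), w)  # norm after RoPE (Llama4 ordering)
pre = rope_pairs(weighted_rms_norm(q, w), 0.3)   # norm before RoPE (Gemma3/Qwen3 ordering)
print(torch.allclose(post, pre))                 # False: the weight does not commute with the rotation

With no weight the two orders would actually coincide, since a rotation preserves the mean square; the point is only that the norm/RoPE ordering is part of the model definition, which is why the refactor leaves it to each subclass's apply_rope override instead of a shared enum.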
@@ -9,7 +9,7 @@ from tensorrt_llm.functional import PositionEmbeddingType
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention, QkNormType
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
@@ -50,19 +50,15 @@ class Qwen3Attention(Attention):
             num_key_value_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
-            pos_embd_params=pos_embd_params
-            if not self.fuse_qk_norm_rope else None,
-            qk_norm_type=QkNormType.pre_rope,
+            pos_embd_params=pos_embd_params,
+            rope_fusion=not self.
+            fuse_qk_norm_rope,  # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
         )
 
-        # If fuse_qk_norm_rope is true, we pass pos_embd_params=None to super().__init__,
-        # so we need to do assignment to record the actual pos_embd_params.
-        self.pos_embd_params = pos_embd_params
-
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=1e-6,
                               dtype=config.torch_dtype,
@@ -94,12 +90,6 @@ class Qwen3Attention(Attention):
 
         return q, k
 
-    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
-        if not self.fuse_qk_norm_rope:
-            return super().apply_rope(qkv, position_ids)
-        else:
-            return self.apply_qk_norm_rope(qkv, position_ids)
-
     def apply_qk_norm_rope(self, qkv, position_ids):
         torch.ops.trtllm.fused_qk_norm_rope(
             qkv, self.num_heads, self.num_key_value_heads,
@@ -109,6 +99,18 @@ class Qwen3Attention(Attention):
             self.pos_embd_params.is_neox, position_ids.view(-1))
         return qkv, None, None
 
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        # Qwen3 applies QK norm before RoPE.
+        if not self.fuse_qk_norm_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
+            return super().apply_rope(q, k, v, position_ids)
+
+        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
+        qkv = q
+        return self.apply_qk_norm_rope(qkv, position_ids)
+
 
 class Qwen3DecoderLayer(DecoderLayer):
 
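The assert in the new Qwen3 apply_rope documents the fused-path contract: with fuse_qk_norm_rope enabled, the method receives the still-concatenated qkv tensor as q (k and v are None) and hands the whole buffer, plus flattened position_ids, to the fused kernel. A small shape sketch of the two paths (head counts and sizes below are made up for illustration; the kernel itself is elided):

import torch

num_heads, num_kv_heads, head_dim, num_tokens = 8, 2, 16, 4
qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim)

# Unfused path: split first, then QK norm and RoPE act on the separate q/k views.
q, k, v = qkv.split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1)
print(q.shape, k.shape, v.shape)  # torch.Size([4, 128]) torch.Size([4, 32]) torch.Size([4, 32])

# Fused path: qkv stays concatenated; a fused qk-norm+RoPE kernel normalizes q/k and
# rotates them inside one op, and apply_rope returns (qkv, None, None) unchanged in shape.
position_ids = torch.arange(num_tokens).view(-1)
print(qkv.shape, position_ids.shape)  # torch.Size([4, 192]) torch.Size([4])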
@@ -1,11 +1,11 @@
 import math
 import weakref
-from enum import IntEnum
 from typing import Optional, Union, cast
 
 import torch
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import (AttentionInputType, AttentionMetadata,
@@ -23,15 +23,6 @@ from .rms_norm import RMSNorm
 from .rotary_embedding import RotaryEmbedding
 
 
-class QkNormType(IntEnum):
-    """
-    The type of QK normalization.
-    """
-    none = 0  # No normalization applied to Q and K
-    pre_rope = 1  # Apply normalization before Rope
-    post_rope = 2  # Apply normalization after Rope
-
-
 class Attention(nn.Module):
 
     def __init__(
@@ -43,7 +34,7 @@ class Attention(nn.Module):
         max_position_embeddings: int,
         bias: bool,
         pos_embd_params: Optional[PositionalEmbeddingParams] = None,
-        qk_norm_type: QkNormType = QkNormType.none,
+        rope_fusion: Optional[bool] = None,
         layer_idx: Optional[int] = None,
         dtype: torch.dtype = None,
         dense_bias: Optional[bool] = None,
@@ -60,14 +51,14 @@
             num_key_value_heads (int): The number of key value heads.
             max_position_embeddings (int): The maximum position embeddings.
             bias (bool): Whether to use bias in the linear layers.
-            pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
-            qk_norm_type (QkNormType): The type of QK normalization.
-            layer_idx (int): The layer index.
+            pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
+            rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip applying unfused RoPE. If None, whether to fuse is decided by the capability of the attention backend.
+            layer_idx (Optional[int]): The layer index.
             dtype (torch.dtype): The data type.
-            dense_bias (bool): Whether to use bias in the output projection layer.
-            config (ModelConfig): The model configuration.
+            dense_bias (Optional[bool]): Whether to use bias in the output projection layer.
+            config (Optional[ModelConfig]): The model configuration.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
-            attention_chunk_size (int): See [Chunked Attention] below.
+            attention_chunk_size (Optional[int]): See [Chunked Attention] below.
         """
         super().__init__()
         self.layer_idx = layer_idx
@@ -81,7 +72,6 @@
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings
         self.pos_embd_params = pos_embd_params
-        self.qk_norm_type = qk_norm_type
         self.dense_bias = dense_bias
         self.q_scaling = q_scaling
 
@@ -169,14 +159,21 @@
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])
 
-        # enable_rope_fusion: Whether to fuse RoPE into the attention OP.
+        # Whether to fuse RoPE into the attention OP.
         # If true, RoPE will be applied in self.attn.forward.
         # If false, RoPE will be applied in self.apply_rope.
-        self.enable_rope_fusion = attn_cls.support_fused_rope(
-        ) and self.qk_norm_type != QkNormType.post_rope
+        self.rope_fusion = rope_fusion
+        if self.rope_fusion and not attn_cls.support_fused_rope():
+            logger.warning(
+                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
+            )
+            self.rope_fusion = False
+        # If rope_fusion is not specified, enable if the attention backend supports it.
+        if self.rope_fusion is None:
+            self.rope_fusion = attn_cls.support_fused_rope()
 
         self.rotary_emb = None
-        if not self.enable_rope_fusion and self.pos_embd_params is not None:
+        if not self.rope_fusion and self.pos_embd_params is not None:
             self.rotary_emb = RotaryEmbedding(
                 self.pos_embd_params.rope,
                 head_dim=self.head_dim,
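Restated outside the class for clarity, the hunk above resolves the three-valued rope_fusion argument roughly as below (a sketch under assumed names: resolve_rope_fusion and backend_supports_fused_rope are illustrative, with the boolean standing in for attn_cls.support_fused_rope()):

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def resolve_rope_fusion(rope_fusion: Optional[bool],
                        backend_supports_fused_rope: bool) -> bool:
    # An explicit True is downgraded with a warning if the backend cannot fuse RoPE.
    if rope_fusion and not backend_supports_fused_rope:
        logger.warning("rope_fusion requested but not supported by the attention "
                       "backend; disabling it.")
        rope_fusion = False
    # Unspecified (None) falls back to whatever the backend supports.
    if rope_fusion is None:
        rope_fusion = backend_supports_fused_rope
    return rope_fusion


assert resolve_rope_fusion(None, True) is True    # default: follow the backend capability
assert resolve_rope_fusion(None, False) is False
assert resolve_rope_fusion(False, True) is False  # caller opts out of fusion
assert resolve_rope_fusion(True, False) is False  # warned, then disabled

Subclasses use the argument to keep RoPE out of the attention op whenever they need to touch q/k around it: Llama4 passes rope_fusion=not use_qk_norm and Qwen3 passes rope_fusion=not fuse_qk_norm_rope, as shown in the hunks above.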
@@ -189,8 +186,7 @@
             self.num_heads,
             self.head_dim,
             self.num_key_value_heads,
-            pos_embd_params=self.pos_embd_params
-            if self.enable_rope_fusion else None,
+            pos_embd_params=self.pos_embd_params if self.rope_fusion else None,
             quant_config=self.quant_config,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             q_scaling=self.q_scaling,
@@ -263,7 +259,9 @@
         if qkv_lora is not None:
             qkv = qkv + qkv_lora
 
-        q, k, v = self.apply_rope(qkv, position_ids)
+        q, k, v = qkv, None, None
+
+        q, k, v = self.apply_rope(q, k, v, position_ids)
 
         out_scale = None
         out_scale_sf = None
@@ -290,32 +288,25 @@
                               layer_idx=self.layer_idx)
         return attn_output
 
-    def apply_qk_norm(self, q, k):
-        raise NotImplementedError(
-            f"QK norm is not implemented for {self.__class__.__name__}."
-            "Please override the `apply_qk_norm` method in the subclass.")
-
-    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
         """
-        Apply RoPE to the query and key, possibly including QK norm.
+        Apply RoPE to the query and key.
+        Depending on the implementation, q, k, v could be either fused (q, k, v = concat(q, k, v), None, None) or unfused (none of q, k, v is None).
+        Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
+        This method could be overridden in the subclass, in which extra functionalities such as q_norm/k_norm could be added.
         Args:
-            qkv (torch.Tensor): The query, key, and value tensor.
+            q (torch.Tensor): The query tensor.
+            k (Optional[torch.Tensor]): The key tensor.
+            v (Optional[torch.Tensor]): The value tensor.
             position_ids (torch.Tensor): The position IDs of each token for RoPE.
         Returns:
             tuple: A tuple of (q, k, v).
-            This method could be overridden in the subclass, it is possible that k/v is None and q is the concatenated qkv tensor, up to the implementation.
-            Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
         """
-        q, k, v = qkv, None, None
-        if self.qk_norm_type == QkNormType.pre_rope:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.apply_qk_norm(q, k)
-        if not self.enable_rope_fusion and position_ids is not None:
-            q, k, v = self.split_qkv(q, k, v)
+        q, k, v = self.split_qkv(q, k, v)
+        # If RoPE is fused into the attention OP, do not apply RoPE here.
+        if not self.rope_fusion and position_ids is not None:
             q, k = self.rotary_emb(position_ids, [q, k])
-            if self.qk_norm_type == QkNormType.post_rope:
-                q, k = self.apply_qk_norm(q, k)
 
         return q, k, v
 
 
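Putting the base-class pieces together: forward now always calls apply_rope with (qkv, None, None), the default implementation splits and applies the unfused rotary embedding only when rope_fusion is off, and subclasses override the same hook to weave QK norm in before or after. A condensed, self-contained sketch of that control flow (toy classes under assumed names, not the real module; convert_qkv and the attention backend are omitted):

from typing import Optional, Tuple

import torch

Triple = Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]


class ToyAttention:
    def __init__(self, q_size: int, kv_size: int, rope_fusion: bool):
        self.q_size, self.kv_size = q_size, kv_size
        self.rope_fusion = rope_fusion  # resolved as in the __init__ hunk above

    def split_qkv(self, q, k, v) -> Triple:
        if k is None and v is None:  # q holds the fused qkv projection output
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q, k, v

    def rotary_emb(self, position_ids, qk):
        return qk  # placeholder for the real RotaryEmbedding

    def apply_rope(self, q, k, v, position_ids) -> Triple:
        q, k, v = self.split_qkv(q, k, v)
        # When RoPE is fused into the attention op, it is applied there instead.
        if not self.rope_fusion and position_ids is not None:
            q, k = self.rotary_emb(position_ids, [q, k])
        return q, k, v

    def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor) -> Triple:
        q, k, v = qkv, None, None                  # mirrors the forward() hunk
        q, k, v = self.apply_rope(q, k, v, position_ids)
        # ... convert_qkv and the attention backend would run here ...
        return q, k, v


q, k, v = ToyAttention(8, 4, rope_fusion=False).forward(torch.randn(3, 16), torch.arange(3))
print(q.shape, k.shape, v.shape)  # torch.Size([3, 8]) torch.Size([3, 4]) torch.Size([3, 4])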
@@ -600,14 +591,14 @@ class MLA(nn.Module):
         self.aux_stream = aux_stream
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
-        self.enable_rope_fusion = self.mha.support_fused_rope()
+        self.rope_fusion = self.mha.support_fused_rope()
         self.support_fused_qkv = self.mha.support_fused_qkv()
         self.rotary_emb = RotaryEmbedding(
             pos_embd_params.rope,
             head_dim=self.qk_rope_head_dim,
             is_neox=pos_embd_params.is_neox,
         )
-        self.apply_rotary_emb = not self.enable_rope_fusion
+        self.apply_rotary_emb = not self.rope_fusion
 
         if not config.skip_create_weights_in_init:
             self.create_weights()