chore: Refactor apply_rope. (#4918)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
Bo Li 2025-06-09 16:51:59 +08:00 committed by GitHub
parent 6b17dff2f1
commit c104388d37
4 changed files with 77 additions and 66 deletions


@@ -14,7 +14,7 @@ from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.attention import Attention, QkNormType
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.linear import Linear, TensorParallelMode
@@ -53,7 +53,6 @@ class Gemma3Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
qk_norm_type=QkNormType.pre_rope,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=False,
@@ -113,6 +112,13 @@ class Gemma3Attention(Attention):
return q, k
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
# Gemma3 applies QK norm before RoPE.
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
return super().apply_rope(q, k, v, position_ids)
class Gemma3MLP(nn.Module):
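
The base class leaves `apply_qk_norm` to the subclass (it raises NotImplementedError by default, as shown in the base-class diff below). For reference, a minimal sketch of what such a per-head QK norm can look like; the reshape-to-head_dim layout and the inline RMSNorm are illustrative assumptions, not the actual Gemma3 code:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm over the last dimension.
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight

def apply_per_head_qk_norm(q: torch.Tensor, k: torch.Tensor,
                           q_weight: torch.Tensor, k_weight: torch.Tensor,
                           head_dim: int):
    # View the packed [num_tokens, num_heads * head_dim] tensors as [..., head_dim],
    # normalize each head independently, then restore the original shapes.
    q = rms_norm(q.reshape(-1, head_dim), q_weight).reshape(q.shape)
    k = rms_norm(k.reshape(-1, head_dim), k_weight).reshape(k.shape)
    return q, k

# Illustrative sizes: 4 tokens, 2 query heads, 1 KV head, head_dim 8.
head_dim = 8
q, k = torch.randn(4, 2 * head_dim), torch.randn(4, head_dim)
q, k = apply_per_head_qk_norm(q, k, torch.ones(head_dim), torch.ones(head_dim), head_dim)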


@@ -23,7 +23,7 @@ from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention, QkNormType
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (Llama4RenormalizeMoeRoutingMethod,
@@ -60,6 +60,7 @@ class Llama4Attention(Attention):
rope=RopeParams.from_config(config),
is_neox=False,
) if self.use_rope else None
self.use_qk_norm = use_qk_norm
if model_config.attn_backend != "TRTLLM":
# TODO: support chunked attention for other backends.
@@ -74,15 +75,15 @@ class Llama4Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
qk_norm_type=QkNormType.post_rope
if use_qk_norm else QkNormType.none,
rope_fusion=not self.
use_qk_norm, # Llama4 uses qk_norm after RoPE, so it is not possible to fuse RoPE into the attention OP with qk_norm.
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
attention_chunk_size=attention_chunk_size,
)
if self.use_rope and use_qk_norm:
if self.use_qk_norm:
self.head_dim = config.hidden_size // config.num_attention_heads
self.qk_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
@@ -115,6 +116,17 @@ class Llama4Attention(Attention):
return q, k
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
q, k, v = self.split_qkv(q, k, v)
if position_ids is not None:
q, k, v = super().apply_rope(q, k, v, position_ids)
# Llama4 applies QK norm after RoPE.
if self.use_qk_norm:
q, k = self.apply_qk_norm(q, k)
return q, k, v
def _attention_scaling(self, q, position_ids):
def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
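
The only difference between the Gemma3 and Llama4 overrides is where the norm sits relative to RoPE, which is also why Llama4 passes `rope_fusion=not self.use_qk_norm`. A schematic sketch with stand-in `norm_fn`/`rope_fn` callables (hypothetical names, not this module's API):

from typing import Callable, Tuple
import torch

def apply_rope_pre_norm(q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor,
                        norm_fn: Callable, rope_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]:
    # Gemma3 / Qwen3 ordering: QK norm first, then RoPE.
    q, k = norm_fn(q, k)
    return rope_fn(q, k, position_ids)

def apply_rope_post_norm(q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor,
                         norm_fn: Callable, rope_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]:
    # Llama4 ordering: RoPE first, then QK norm. The norm must see the rotated q/k,
    # so RoPE cannot be deferred into the attention OP (rope_fusion is disabled).
    q, k = rope_fn(q, k, position_ids)
    return norm_fn(q, k)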


@@ -9,7 +9,7 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention, QkNormType
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
@@ -50,19 +50,15 @@ class Qwen3Attention(Attention):
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params
if not self.fuse_qk_norm_rope else None,
qk_norm_type=QkNormType.pre_rope,
pos_embd_params=pos_embd_params,
rope_fusion=not self.
fuse_qk_norm_rope, # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
)
# If fuse_qk_norm_rope is true, we pass pos_embd_params=None to super().__init__,
# so we need to do assignment to record the actual pos_embd_params.
self.pos_embd_params = pos_embd_params
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
@@ -94,12 +90,6 @@ class Qwen3Attention(Attention):
return q, k
def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
if not self.fuse_qk_norm_rope:
return super().apply_rope(qkv, position_ids)
else:
return self.apply_qk_norm_rope(qkv, position_ids)
def apply_qk_norm_rope(self, qkv, position_ids):
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
@@ -109,6 +99,18 @@ class Qwen3Attention(Attention):
self.pos_embd_params.is_neox, position_ids.view(-1))
return qkv, None, None
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
# Qwen3 applies QK norm before RoPE.
if not self.fuse_qk_norm_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
return super().apply_rope(q, k, v, position_ids)
assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
qkv = q
return self.apply_qk_norm_rope(qkv, position_ids)
class Qwen3DecoderLayer(DecoderLayer):
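
When `fuse_qk_norm_rope` is enabled, `apply_rope` receives the still-packed qkv tensor (hence the assert that k and v are None) so that `torch.ops.trtllm.fused_qk_norm_rope` can normalize and rotate it in one kernel. A sketch of the assumed per-token packing ([q | k | v] along the last dimension) and how an unfused path would split it; sizes are illustrative:

import torch

def split_packed_qkv(qkv: torch.Tensor, num_heads: int,
                     num_kv_heads: int, head_dim: int):
    # qkv: [num_tokens, (num_heads + 2 * num_kv_heads) * head_dim]
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    return q, k, v

# Hypothetical sizes for illustration: 4 tokens, 8 q heads, 2 kv heads, head_dim 64.
qkv = torch.randn(4, (8 + 2 * 2) * 64)
q, k, v = split_packed_qkv(qkv, num_heads=8, num_kv_heads=2, head_dim=64)
assert q.shape == (4, 8 * 64) and k.shape == (4, 2 * 64) and v.shape == (4, 2 * 64)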


@@ -1,11 +1,11 @@
import math
import weakref
from enum import IntEnum
from typing import Optional, Union, cast
import torch
from torch import nn
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..attention_backend import (AttentionInputType, AttentionMetadata,
@@ -23,15 +23,6 @@ from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding
class QkNormType(IntEnum):
"""
The type of QK normalization.
"""
none = 0 # No normalization applied to Q and K
pre_rope = 1 # Apply normalization before Rope
post_rope = 2 # Apply normalization after Rope
class Attention(nn.Module):
def __init__(
@@ -43,7 +34,7 @@ class Attention(nn.Module):
max_position_embeddings: int,
bias: bool,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
qk_norm_type: QkNormType = QkNormType.none,
rope_fusion: Optional[bool] = None,
layer_idx: Optional[int] = None,
dtype: torch.dtype = None,
dense_bias: Optional[bool] = None,
@@ -60,14 +51,14 @@ class Attention(nn.Module):
num_key_value_heads (int): The number of key value heads.
max_position_embeddings (int): The maximum position embeddings.
bias (bool): Whether to use bias in the linear layers.
pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
qk_norm_type (QkNormType): The type of QK normalization.
layer_idx (int): The layer index.
pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip applying unfused RoPE. If None, whether to fuse is decided by the capability of the attention backend.
layer_idx (Optional[int]): The layer index.
dtype (torch.dtype): The data type.
dense_bias (bool): Whether to use bias in the output projection layer.
config (ModelConfig): The model configuration.
dense_bias (Optional[bool]): Whether to use bias in the output projection layer.
config (Optional[ModelConfig]): The model configuration.
q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
attention_chunk_size (int): See [Chunked Attention] below.
attention_chunk_size (Optional[int]): See [Chunked Attention] below.
"""
super().__init__()
self.layer_idx = layer_idx
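
As a quick numeric check of the q_scaling docstring above, with qk_scale = 1 / (sqrt(head_dim) * q_scaling):

import math

def qk_scale(head_dim: int, q_scaling: float = 1.0) -> float:
    # O = softmax(Q K^T * qk_scale) V, with qk_scale = 1 / (sqrt(head_dim) * q_scaling).
    return 1.0 / (math.sqrt(head_dim) * q_scaling)

print(qk_scale(128))        # 0.0883883... == 1 / sqrt(128) with the default q_scaling of 1.0
print(qk_scale(128, 2.0))   # half of that when q_scaling is 2.0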
@@ -81,7 +72,6 @@ class Attention(nn.Module):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.pos_embd_params = pos_embd_params
self.qk_norm_type = qk_norm_type
self.dense_bias = dense_bias
self.q_scaling = q_scaling
@@ -169,14 +159,21 @@ class Attention(nn.Module):
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
# enable_rope_fusion: Whether to fuse RoPE into the attention OP.
# Whether to fuse RoPE into the attention OP.
# If true, RoPE will be applied in self.attn.forward.
# If false, RoPE will be applied in self.apply_rope.
self.enable_rope_fusion = attn_cls.support_fused_rope(
) and self.qk_norm_type != QkNormType.post_rope
self.rope_fusion = rope_fusion
if self.rope_fusion and not attn_cls.support_fused_rope():
logger.warning(
"rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
)
self.rope_fusion = False
# If rope_fusion is not specified, enable if the attention backend supports it.
if self.rope_fusion is None:
self.rope_fusion = attn_cls.support_fused_rope()
self.rotary_emb = None
if not self.enable_rope_fusion and self.pos_embd_params is not None:
if not self.rope_fusion and self.pos_embd_params is not None:
self.rotary_emb = RotaryEmbedding(
self.pos_embd_params.rope,
head_dim=self.head_dim,
@@ -189,8 +186,7 @@ class Attention(nn.Module):
self.num_heads,
self.head_dim,
self.num_key_value_heads,
pos_embd_params=self.pos_embd_params
if self.enable_rope_fusion else None,
pos_embd_params=self.pos_embd_params if self.rope_fusion else None,
quant_config=self.quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
q_scaling=self.q_scaling,
@@ -263,7 +259,9 @@ class Attention(nn.Module):
if qkv_lora is not None:
qkv = qkv + qkv_lora
q, k, v = self.apply_rope(qkv, position_ids)
q, k, v = qkv, None, None
q, k, v = self.apply_rope(q, k, v, position_ids)
out_scale = None
out_scale_sf = None
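
The forward path now always seeds the call as `q, k, v = qkv, None, None` and lets `apply_rope` decide whether to return packed or split tensors; `convert_qkv` then reconciles the result with what the backend expects. `convert_qkv` itself is not part of this diff, so the helper below is only an assumed illustration of that reconciliation step:

from typing import Optional
import torch

def normalize_qkv_layout(q: torch.Tensor, k: Optional[torch.Tensor], v: Optional[torch.Tensor],
                         want_fused_qkv: bool, q_size: int, kv_size: int):
    # apply_rope may return either a packed qkv (k is None and v is None) or split q/k/v;
    # re-pack or split so the layout matches what the attention backend consumes.
    packed = k is None and v is None
    if want_fused_qkv and not packed:
        return torch.cat([q, k, v], dim=-1), None, None
    if not want_fused_qkv and packed:
        return q.split([q_size, kv_size, kv_size], dim=-1)
    return q, k, v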
@@ -290,32 +288,25 @@ class Attention(nn.Module):
layer_idx=self.layer_idx)
return attn_output
def apply_qk_norm(self, q, k):
raise NotImplementedError(
f"QK norm is not implemented for {self.__class__.__name__}."
"Please override the `apply_qk_norm` method in the subclass.")
def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
"""
Apply RoPE to the query and key, possibly including QK norm.
Apply RoPE to the query and key.
Depending on the implementation, q, k, v could be either fused (q, k, v = concat(q, k, v), None, None) or unfused (none of q, k, v is None).
Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
This method could be overridden in the subclass, in which extra functionalities such as q_norm/k_norm could be added.
Args:
qkv (torch.Tensor): The query, key, and value tensor.
q (torch.Tensor): The query tensor.
k (Optional[torch.Tensor]): The key tensor.
v (Optional[torch.Tensor]): The value tensor.
position_ids (torch.Tensor): The position IDs of each token for RoPE.
Returns:
tuple: A tuple of (q, k, v).
This method could be overridden in the subclass, it is possible that k/v is None and q is the concatenated qkv tensor, up to the implementation.
Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
"""
q, k, v = qkv, None, None
if self.qk_norm_type == QkNormType.pre_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if not self.enable_rope_fusion and position_ids is not None:
q, k, v = self.split_qkv(q, k, v)
q, k, v = self.split_qkv(q, k, v)
# If RoPE is fused into the attention OP, do not apply RoPE here.
if not self.rope_fusion and position_ids is not None:
q, k = self.rotary_emb(position_ids, [q, k])
if self.qk_norm_type == QkNormType.post_rope:
q, k = self.apply_qk_norm(q, k)
return q, k, v
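
For readers unfamiliar with what `self.rotary_emb(position_ids, [q, k])` does conceptually, here is a minimal NeoX-style RoPE sketch; the actual RotaryEmbedding module supports more layouts and caching, so treat this only as an approximation:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary(q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor,
                 head_dim: int, base: float = 10000.0):
    # q, k: [num_tokens, num_heads, head_dim]; position_ids: [num_tokens]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = position_ids.float()[:, None] * inv_freq[None, :]         # [num_tokens, head_dim // 2]
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[:, None, :]  # broadcast over heads
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[:, None, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin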
@@ -600,14 +591,14 @@ class MLA(nn.Module):
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.enable_rope_fusion = self.mha.support_fused_rope()
self.rope_fusion = self.mha.support_fused_rope()
self.support_fused_qkv = self.mha.support_fused_qkv()
self.rotary_emb = RotaryEmbedding(
pos_embd_params.rope,
head_dim=self.qk_rope_head_dim,
is_neox=pos_embd_params.is_neox,
)
self.apply_rotary_emb = not self.enable_rope_fusion
self.apply_rotary_emb = not self.rope_fusion
if not config.skip_create_weights_in_init:
self.create_weights()
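
The MLA path keeps the same capability check. The rope_fusion resolution added to Attention.__init__ can be summarized as a small helper (same logic and warning text as above, extracted here only for illustration):

from typing import Optional
import warnings

def resolve_rope_fusion(requested: Optional[bool], backend_supports_fused_rope: bool) -> bool:
    # An explicit True is downgraded with a warning when the backend cannot fuse RoPE;
    # None defers to the backend's capability; False stays False.
    if requested and not backend_supports_fused_rope:
        warnings.warn("rope_fusion is true but the attention backend does not support it. "
                      "Will disable rope_fusion.")
        return False
    if requested is None:
        return backend_supports_fused_rope
    return requested

assert resolve_rope_fusion(None, True) is True
assert resolve_rope_fusion(True, False) is False
assert resolve_rope_fusion(False, True) is False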