feat: Integration of Fused QKNorm+RoPE. (#4611)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Bo Li 2025-05-28 11:20:45 +08:00 committed by GitHub
parent 6493401986
commit 9c4b8f66b4
5 changed files with 72 additions and 32 deletions


@@ -53,11 +53,11 @@ class Gemma3Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
qk_norm_type=QkNormType.pre_rope,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=False,
config=model_config,
qk_norm_type=QkNormType.pre_rope,
q_scaling=q_scaling,
)
self.q_norm = RMSNorm(hidden_size=config.head_dim,


@@ -73,11 +73,11 @@ class Llama4Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
qk_norm_type=QkNormType.post_rope
if use_qk_norm else QkNormType.none,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
qk_norm_type=QkNormType.post_rope
if use_qk_norm else QkNormType.none,
attention_chunk_size=attention_chunk_size,
)


@@ -26,8 +26,10 @@ class Qwen3Attention(Attention):
self,
model_config: ModelConfig[Qwen3Config],
layer_idx: Optional[int] = None,
fuse_qk_norm_rope: bool = True,
):
config = model_config.pretrained_config
if getattr(config, "rope_scaling", None) is not None:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.from_string(
@@ -40,20 +42,27 @@ class Qwen3Attention(Attention):
rope=RopeParams.from_config(config),
)
self.fuse_qk_norm_rope = fuse_qk_norm_rope
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
pos_embd_params=pos_embd_params
if not self.fuse_qk_norm_rope else None,
qk_norm_type=QkNormType.pre_rope,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
qk_norm_type=QkNormType.pre_rope,
)
# If fuse_qk_norm_rope is true, we pass pos_embd_params=None to super().__init__,
# so we assign it here to record the actual pos_embd_params.
self.pos_embd_params = pos_embd_params
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
@@ -85,6 +94,21 @@ class Qwen3Attention(Attention):
return q, k
def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
if not self.fuse_qk_norm_rope:
return super().apply_rope(qkv, position_ids)
else:
return self.apply_qk_norm_rope(qkv, position_ids)
def apply_qk_norm_rope(self, qkv, position_ids):
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
self.num_key_value_heads, self.head_dim,
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight, self.pos_embd_params.rope.theta,
self.pos_embd_params.is_neox, position_ids.view(-1))
return qkv, None, None
class Qwen3DecoderLayer(DecoderLayer):
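
For orientation, the snippet below is a rough, unfused reference for what the torch.ops.trtllm.fused_qk_norm_rope call above is expected to compute on the packed qkv tensor: per-head RMSNorm over the query and key heads, followed by NeoX-style RoPE, with the value heads passed through unchanged. The helper names, the NeoX rotation layout, and the float32 math are assumptions made for illustration; the kernel's actual implementation and numerics may differ.

import torch

def rms_norm(x, weight, eps):
    # Per-head RMSNorm over the last (head_dim) axis.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def rope_neox(x, position_ids, theta):
    # NeoX-style RoPE: rotate the two halves of each head (layout assumed).
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device) / head_dim))
    angles = position_ids.float()[:, None] * inv_freq[None, :]       # [tokens, head_dim // 2]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]    # broadcast over heads
    x1, x2 = x[..., :head_dim // 2].float(), x[..., head_dim // 2:].float()
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).to(x.dtype)

def qk_norm_rope_reference(qkv, num_q, num_k, num_v, head_dim, eps,
                           q_weight, k_weight, theta, position_ids):
    # qkv: [num_tokens, (num_q + num_k + num_v) * head_dim], position_ids: [num_tokens].
    heads = qkv.view(qkv.shape[0], num_q + num_k + num_v, head_dim)
    q, k, v = heads.split([num_q, num_k, num_v], dim=1)
    q = rope_neox(rms_norm(q, q_weight, eps), position_ids, theta)
    k = rope_neox(rms_norm(k, k_weight, eps), position_ids, theta)
    # The real op works in place on qkv; this reference returns a new tensor.
    return torch.cat([q, k, v], dim=1).view(qkv.shape[0], -1)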


@@ -43,11 +43,11 @@ class Attention(nn.Module):
max_position_embeddings: int,
bias: bool,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
qk_norm_type: QkNormType = QkNormType.none,
layer_idx: Optional[int] = None,
dtype: torch.dtype = None,
dense_bias: Optional[bool] = None,
config: Optional[ModelConfig] = None,
qk_norm_type: QkNormType = QkNormType.none,
q_scaling: float = 1.0,
attention_chunk_size: Optional[int] = None,
):
@@ -61,11 +61,11 @@
max_position_embeddings (int): The maximum position embeddings.
bias (bool): Whether to use bias in the linear layers.
pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
qk_norm_type (QkNormType): The type of QK normalization.
layer_idx (int): The layer index.
dtype (torch.dtype): The data type.
dense_bias (bool): Whether to use bias in the output projection layer.
config (ModelConfig): The model configuration.
qk_norm_type (QkNormType): The type of QK normalization.
q_scaling (float): The scaling factor applied to qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V$, where $qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
attention_chunk_size (int): See [Chunked Attention] below.
"""
@@ -154,7 +154,7 @@
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
self.pos_embd_params = pos_embd_params
attn_cls = get_attention_backend(self.attn_backend)
# These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
# but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
@@ -169,9 +169,20 @@
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
attn_cls = get_attention_backend(self.attn_backend)
# enable_rope_fusion: Whether to fuse RoPE into the attention OP.
# If true, RoPE will be applied in self.attn.forward.
# If false, RoPE will be applied in self.apply_rope.
self.enable_rope_fusion = attn_cls.support_fused_rope(
) and qk_norm_type != QkNormType.post_rope
) and self.qk_norm_type != QkNormType.post_rope
self.rotary_emb = None
if not self.enable_rope_fusion and self.pos_embd_params is not None:
self.rotary_emb = RotaryEmbedding(
self.pos_embd_params.rope,
head_dim=self.head_dim,
is_neox=self.pos_embd_params.is_neox,
)
self.attn = create_attention(
self.attn_backend,
self.layer_idx,
@@ -188,16 +199,6 @@
self.support_fused_qkv = self.attn.support_fused_qkv()
self.rotary_emb = None
self.apply_rotary_emb = (not self.enable_rope_fusion
and pos_embd_params is not None)
if self.apply_rotary_emb:
self.rotary_emb = RotaryEmbedding(
pos_embd_params.rope,
head_dim=self.head_dim,
is_neox=pos_embd_params.is_neox,
)
if not config.skip_create_weights_in_init:
self.create_weights()
@@ -261,17 +262,9 @@
if qkv_lora is not None:
qkv = qkv + qkv_lora
q, k, v = qkv, None, None
if self.qk_norm_type == QkNormType.pre_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if self.apply_rotary_emb and position_ids is not None:
q, k, v = self.split_qkv(q, k, v)
q, k = self.rotary_emb(position_ids, [q, k])
if self.qk_norm_type == QkNormType.post_rope:
q, k = self.apply_qk_norm(q, k)
out_scale = None
q, k, v = self.apply_rope(qkv, position_ids)
out_scale = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
out_scale = self.o_proj.inv_input_scale
@@ -297,6 +290,29 @@
f"QK norm is not implemented for {self.__class__.__name__}."
"Please override the `apply_qk_norm` method in the subclass.")
def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
"""
Apply RoPE to the query and key, possibly including QK norm.
Args:
qkv (torch.Tensor): The query, key, and value tensor.
position_ids (torch.Tensor): The position IDs of each token for RoPE.
Returns:
tuple: A tuple of (q, k, v).
This method may be overridden in subclasses; depending on the implementation, k/v may be None and q may be the concatenated qkv tensor.
Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
"""
q, k, v = qkv, None, None
if self.qk_norm_type == QkNormType.pre_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if not self.enable_rope_fusion and position_ids is not None:
q, k, v = self.split_qkv(q, k, v)
q, k = self.rotary_emb(position_ids, [q, k])
if self.qk_norm_type == QkNormType.post_rope:
q, k = self.apply_qk_norm(q, k)
return q, k, v
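
Putting the pieces together, here is a simplified sketch of the call order in Attention.forward after this change (LoRA and quantization branches omitted; an illustration of the flow, not the exact code):

# Illustration only: simplified from the forward() shown in the diff above.
qkv = self.qkv_proj(hidden_states)
# Base-class apply_rope: optional pre-RoPE QK norm plus RotaryEmbedding when the
# backend does not fuse RoPE; the Qwen3 override instead issues a single
# fused_qk_norm_rope call on the packed qkv and returns (qkv, None, None).
q, k, v = self.apply_rope(qkv, position_ids)
# convert_qkv brings (q, k, v) into the layout self.attn expects
# (fused qkv vs. separate tensors) before the attention backend runs.
q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q, k, v, attn_metadata, out_scale=out_scale)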
def extract_extra_attrs(layer_idx: str):
extra_attrs = get_model_extra_attrs()


@@ -52,10 +52,10 @@ class DecodingCUDAGraphRunner:
# Using ones instead of zeros prevents NaNs in e.g. Deepseek
self.input_ids = torch.ones((batch_size * token_per_request, ),
device=device,
dtype=torch.int64)
dtype=torch.int32)
self.position_ids = torch.zeros((1, batch_size * token_per_request),
device=device,
dtype=torch.int64)
dtype=torch.int32)
self.extra_model_inputs = {}
self.attn_metadata = attn_metadata
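
The int64 to int32 change presumably aligns these pre-captured buffers with what the updated attention path consumes (the fused op receives a flattened position_ids view). Below is a minimal sketch, under that assumption, of how captured buffers are refreshed in place between graph replays; new_input_ids and new_position_ids are hypothetical stand-ins for the scheduler's per-step outputs.

import torch

batch_size, token_per_request, device = 8, 1, "cuda"
# Pre-captured inputs: ones for input_ids avoid NaNs (see the comment in the diff),
# and int32 matches the dtype used at capture time.
input_ids = torch.ones((batch_size * token_per_request,), device=device, dtype=torch.int32)
position_ids = torch.zeros((1, batch_size * token_per_request), device=device, dtype=torch.int32)

# Per-step refresh before replaying the graph: copy in place so the storage the
# graph captured stays valid; reallocating or changing dtype would break replay.
new_input_ids = torch.randint(0, 32000, (batch_size * token_per_request,), device=device)
new_position_ids = torch.arange(batch_size * token_per_request, device=device).view(1, -1)
input_ids.copy_(new_input_ids.to(torch.int32))
position_ids.copy_(new_position_ids.to(torch.int32))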