refactor: Allow models to override apply_qk_norm. (#4078)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
yuxianq 2025-05-12 19:38:24 +08:00 committed by GitHub
parent c9e2a963e0
commit b35f9a67f9
4 changed files with 111 additions and 180 deletions
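
In short, the base Attention class now accepts a qk_norm_type argument (QkNormType.none / pre_rope / post_rope) and dispatches to an apply_qk_norm hook from its forward path, so each model only implements the hook itself. The following is a minimal sketch of that override pattern, closely mirroring the Qwen3Attention changes in the diff below; it is not part of the commit. The class name MyAttention is hypothetical, the rope_gpt_neox positional-embedding choice is an assumption borrowed from the removed Qwen3MoEAttention code, and the pretrained-config fields are assumed to follow the usual Hugging Face layout used by the other model files here.

from typing import Optional

import torch

from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend.interface import (PositionalEmbeddingParams,
                                           RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention, QkNormType
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm


class MyAttention(Attention):
    """Hypothetical model attention that opts into the new QK-norm hook."""

    def __init__(self, model_config: ModelConfig, layer_idx: Optional[int] = None):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.attention_bias,
            pos_embd_params=PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gpt_neox,
                rope=RopeParams.from_config(config),
            ),
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
            # pre_rope: normalize q/k before RoPE (Qwen3 style);
            # post_rope: normalize q/k after RoPE (Llama4 style, which also
            # disables the backend's fused RoPE path).
            qk_norm_type=QkNormType.pre_rope,
        )
        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=1e-6,
                              dtype=config.torch_dtype,
                              has_weights=True)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=1e-6,
                              dtype=config.torch_dtype,
                              has_weights=True)
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def apply_qk_norm(self, q, k):
        # Called by the base Attention.forward() either before or after the
        # rotary embedding, depending on qk_norm_type.
        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        # Overlap the two norms on an auxiliary CUDA stream when possible.
        return maybe_execute_in_parallel(q_l2norm, k_l2norm,
                                         self.ln_events[0],
                                         self.ln_events[1], self.aux_stream)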


@@ -22,7 +22,7 @@ from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (FusedMoE, Llama4RenormalizeMoeRoutingMethod,
@@ -32,7 +32,6 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -53,86 +52,51 @@ class Llama4Attention(Attention):
aux_stream: Optional[torch.cuda.Stream] = None,
):
config = model_config.pretrained_config
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.use_rope = not nope_layer
self.use_qk_norm = use_qk_norm and not nope_layer
if self.use_rope and not self.use_qk_norm:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gptj,
rope=RopeParams.from_config(config),
is_neox=False,
)
else:
pos_embd_params = None
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gptj,
rope=RopeParams.from_config(config),
is_neox=False,
) if self.use_rope else None
super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
qk_norm_type=QkNormType.post_rope
if use_qk_norm else QkNormType.none,
)
if self.use_rope and self.use_qk_norm:
# here we must disable rope fusion regardless of attn_backend
self.enable_rope_fusion = False
self.rotary_emb = RotaryEmbedding(
RopeParams.from_config(config),
head_dim=self.head_dim,
is_neox=False,
)
if self.use_qk_norm:
if self.use_rope and use_qk_norm:
self.head_dim = config.hidden_size // config.num_attention_heads
self.qk_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=False)
else:
self.qk_norm = None
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.attn_temperature_tuning = attn_temperature_tuning and nope_layer
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
def _attn_qkv(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None):
out_scale = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
out_scale = self.o_proj.inv_input_scale
def apply_qk_norm(self, q, k):
q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q,
k,
v,
attn_metadata,
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config)
def q_l2norm():
return self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
def k_l2norm():
return self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
return attn_output
def _qk_norm(self, q, k):
# TODO: make this more efficient.
q_l2norm = lambda: self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
k_l2norm = lambda: self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
@@ -155,31 +119,6 @@ class Llama4Attention(Attention):
q = (q * attn_scale).to(q.dtype)
return q
def _forward_rope(
self,
position_ids: Optional[torch.LongTensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
):
if self.use_qk_norm:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
assert self.rotary_emb is not None and not self.enable_rope_fusion, "qk_norm requires attention rope fusion disabled"
q, k = self.rotary_emb(position_ids, [q, k])
q, k = self._qk_norm(q, k)
return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
mrope_config, all_reduce_params)
else:
# When qk_norm is disabled, use the classic attention path that handles RoPE fusion
return super().forward(position_ids, hidden_states, attn_metadata,
attention_mask, mrope_config,
all_reduce_params)
def _forward_nope(
self,
position_ids: Optional[torch.LongTensor],
@@ -191,11 +130,26 @@ class Llama4Attention(Attention):
all_reduce_params: Optional[AllReduceParams] = None,
):
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k, v = self.split_qkv(qkv)
if self.attn_temperature_tuning:
q = self._attention_scaling(q, position_ids)
return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
mrope_config, all_reduce_params)
out_scale = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
out_scale = self.o_proj.inv_input_scale
q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q,
k,
v,
attn_metadata,
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config)
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
return attn_output
def forward(
self,
@@ -211,9 +165,16 @@
) -> torch.Tensor:
assert lora_params is None, "LORA is not supported for Llama4Attention"
if self.use_rope:
return self._forward_rope(position_ids, hidden_states,
attn_metadata, attention_mask,
mrope_config, all_reduce_params)
return super().forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
mrope_config=mrope_config,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
**kwargs,
)
else:
return self._forward_nope(position_ids, hidden_states,
attn_metadata, attention_mask,


@@ -9,11 +9,12 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -51,6 +52,7 @@ class Qwen3Attention(Attention):
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
qk_norm_type=QkNormType.pre_rope,
)
self.q_norm = RMSNorm(hidden_size=self.head_dim,
@@ -64,6 +66,26 @@ class Qwen3Attention(Attention):
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
class Qwen3DecoderLayer(DecoderLayer):


@@ -5,17 +5,14 @@ from torch import nn
from tqdm import tqdm
from transformers import Qwen3MoeConfig
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from .modeling_qwen3 import Qwen3Attention
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
duplicate_kv_weight, register_auto_model)
@@ -81,57 +78,13 @@ class Qwen3MoE(nn.Module):
return final_hidden_states.view(orig_shape)
class Qwen3MoEAttention(Attention):
def __init__(
self,
model_config: ModelConfig[Qwen3MoeConfig],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
if getattr(config, "rope_scaling", None) is not None:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.from_string(
config.rope_scaling["type"]),
rope=RopeParams.from_config(config),
)
else:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
)
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
class Qwen3MoEDecoderLayer(DecoderLayer):
def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
layer_idx: int, aux_stream: torch.cuda.Stream):
super().__init__()
config = model_config.pretrained_config
self.self_attn = Qwen3MoEAttention(
self.self_attn = Qwen3Attention(
model_config,
layer_idx=layer_idx,
)


@@ -1,4 +1,5 @@
import math
from enum import IntEnum
from typing import Optional
import torch
@@ -19,6 +20,12 @@ from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding
class QkNormType(IntEnum):
none = 0
pre_rope = 1
post_rope = 2
class Attention(nn.Module):
def __init__(
@@ -34,6 +41,7 @@
dtype: torch.dtype = None,
dense_bias: Optional[bool] = None,
config: Optional[ModelConfig] = None,
qk_norm_type: QkNormType = QkNormType.none,
):
super().__init__()
self.layer_idx = layer_idx
@@ -41,15 +49,13 @@
config = config or ModelConfig()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
if config:
self.head_dim = getattr(config.pretrained_config, "head_dim",
self.hidden_size // self.num_heads)
else:
self.head_dim = self.hidden_size // self.num_heads
self.head_dim = getattr(config.pretrained_config, "head_dim",
self.hidden_size // self.num_heads)
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.pos_embd_params = pos_embd_params
self.qk_norm_type = qk_norm_type
self.dense_bias = dense_bias
if dense_bias is None:
@@ -119,12 +125,9 @@
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
use_qk_norm = (config.pretrained_config and
(config.pretrained_config.model_type == 'qwen3'
or config.pretrained_config.model_type == 'qwen3_moe'))
attn_cls = get_attention_backend(self.attn_backend)
self.enable_rope_fusion = attn_cls.support_fused_rope(
) and not use_qk_norm
) and qk_norm_type != QkNormType.post_rope
self.attn = create_attention(
self.attn_backend,
self.layer_idx,
@@ -157,9 +160,14 @@
# which could be modified after __init__
self.attn.update_quant_config(self.quant_config)
def split_qkv(self, q, k=None, v=None):
if k is None and v is None:
q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
return q, k, v
def convert_qkv(self, q, k, v):
if k is None and v is None and not self.support_fused_qkv:
q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k, v = self.split_qkv(q)
elif k is not None and v is not None and self.support_fused_qkv:
qkv = torch.concat([q, k, v], dim=-1)
q, k, v = qkv, None, None
@@ -191,32 +199,14 @@
qkv = qkv + qkv_lora
q, k, v = qkv, None, None
if self.qk_norm_type == QkNormType.pre_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if self.apply_rotary_emb and position_ids is not None:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if hasattr(self, 'q_norm') and hasattr(self, 'k_norm'):
# Add qk-norm
if hasattr(self, 'ln_events'):
q_l2norm = lambda: self.q_norm(q.reshape(-1, self.head_dim)
).reshape(-1, self.q_size)
k_l2norm = lambda: self.k_norm(k.reshape(-1, self.head_dim)
).reshape(-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k, v = self.split_qkv(q, k, v)
q, k = self.rotary_emb(position_ids, [q, k])
if self.qk_norm_type == QkNormType.post_rope:
q, k = self.apply_qk_norm(q, k)
out_scale = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
@@ -237,6 +227,11 @@
layer_idx=self.layer_idx)
return attn_output
def apply_qk_norm(self, q, k):
raise NotImplementedError(
f"QK norm is not implemented for {self.__class__.__name__}."
"Please override the `apply_qk_norm` method in the subclass.")
class MLA(nn.Module):