diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 8f6c23356b..e089128265 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -22,7 +22,7 @@ from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..model_config import ModelConfig
-from ..modules.attention import Attention
+from ..modules.attention import Attention, QkNormType
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (FusedMoE, Llama4RenormalizeMoeRoutingMethod,
@@ -32,7 +32,6 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
                               WeightsLoadingConfig)
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from ..modules.rotary_embedding import RotaryEmbedding
 from ..speculative import Eagle3SpecMetadata, SpecMetadata
 from .modeling_multimodal_utils import fuse_input_embeds
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -53,86 +52,51 @@ class Llama4Attention(Attention):
         aux_stream: Optional[torch.cuda.Stream] = None,
     ):
         config = model_config.pretrained_config
-        self.aux_stream = aux_stream
-        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
         self.use_rope = not nope_layer
-        self.use_qk_norm = use_qk_norm and not nope_layer
-        if self.use_rope and not self.use_qk_norm:
-            pos_embd_params = PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gptj,
-                rope=RopeParams.from_config(config),
-                is_neox=False,
-            )
-        else:
-            pos_embd_params = None
+        pos_embd_params = PositionalEmbeddingParams(
+            type=PositionEmbeddingType.rope_gptj,
+            rope=RopeParams.from_config(config),
+            is_neox=False,
+        ) if self.use_rope else None
 
-        super().__init__(hidden_size=config.hidden_size,
-                         num_attention_heads=config.num_attention_heads,
-                         num_key_value_heads=config.num_key_value_heads,
-                         max_position_embeddings=config.max_position_embeddings,
-                         bias=config.attention_bias,
-                         pos_embd_params=pos_embd_params,
-                         layer_idx=layer_idx,
-                         dtype=config.torch_dtype,
-                         config=model_config)
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=config.attention_bias,
+            pos_embd_params=pos_embd_params,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            config=model_config,
+            qk_norm_type=QkNormType.post_rope
+            if use_qk_norm else QkNormType.none,
+        )
 
-        if self.use_rope and self.use_qk_norm:
-            # here we must disable rope fusion regardless of attn_backend
-            self.enable_rope_fusion = False
-            self.rotary_emb = RotaryEmbedding(
-                RopeParams.from_config(config),
-                head_dim=self.head_dim,
-                is_neox=False,
-            )
-
-        if self.use_qk_norm:
+        if self.use_rope and use_qk_norm:
             self.head_dim = config.hidden_size // config.num_attention_heads
             self.qk_norm = RMSNorm(hidden_size=self.head_dim,
                                    eps=1e-6,
                                    dtype=config.torch_dtype,
                                    has_weights=False)
-        else:
-            self.qk_norm = None
+        self.aux_stream = aux_stream
+        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
         self.attn_temperature_tuning = attn_temperature_tuning and nope_layer
         self.floor_scale = getattr(config, "floor_scale", 8192.0)
         self.attn_scale = getattr(config, "attn_scale", 0.1)
 
-    def _attn_qkv(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
-        mrope_config: Optional[dict] = None,
-        all_reduce_params: Optional[AllReduceParams] = None):
-        out_scale = None
-        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
-            out_scale = self.o_proj.inv_input_scale
+    def apply_qk_norm(self, q, k):
 
-        q, k, v = self.convert_qkv(q, k, v)
-        attn_output = self.attn.forward(q,
-                                        k,
-                                        v,
-                                        attn_metadata,
-                                        out_scale=out_scale,
-                                        attention_mask=attention_mask,
-                                        mrope_config=mrope_config)
+        def q_l2norm():
+            return self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
+                -1, self.q_size)
 
-        attn_output = self.o_proj(attn_output,
-                                  all_reduce_params=all_reduce_params)
+        def k_l2norm():
+            return self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
+                -1, self.kv_size)
 
-        return attn_output
-
-    def _qk_norm(self, q, k):
-        # TODO: make this more efficient.
-        q_l2norm = lambda: self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
-            -1, self.q_size)
-        k_l2norm = lambda: self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
-            -1, self.kv_size)
         q, k = maybe_execute_in_parallel(
             q_l2norm,
             k_l2norm,
@@ -155,31 +119,6 @@ class Llama4Attention(Attention):
         q = (q * attn_scale).to(q.dtype)
         return q
 
-    def _forward_rope(
-        self,
-        position_ids: Optional[torch.LongTensor],
-        hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
-        mrope_config: Optional[dict] = None,
-        all_reduce_params: Optional[AllReduceParams] = None,
-    ):
-        if self.use_qk_norm:
-            qkv = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-            assert self.rotary_emb is not None and not self.enable_rope_fusion, "qk_norm requires attention rope fusion disabled"
-            q, k = self.rotary_emb(position_ids, [q, k])
-            q, k = self._qk_norm(q, k)
-            return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
-                                  mrope_config, all_reduce_params)
-        else:
-            # When qk_norm is disabled, use the classic attention path that handles RoPE fusion
-            return super().forward(position_ids, hidden_states, attn_metadata,
-                                   attention_mask, mrope_config,
-                                   all_reduce_params)
-
     def _forward_nope(
         self,
         position_ids: Optional[torch.LongTensor],
@@ -191,11 +130,26 @@ class Llama4Attention(Attention):
         all_reduce_params: Optional[AllReduceParams] = None,
     ):
         qkv = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k, v = self.split_qkv(qkv)
         if self.attn_temperature_tuning:
             q = self._attention_scaling(q, position_ids)
-        return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
-                              mrope_config, all_reduce_params)
+        out_scale = None
+        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
+            out_scale = self.o_proj.inv_input_scale
+
+        q, k, v = self.convert_qkv(q, k, v)
+        attn_output = self.attn.forward(q,
+                                        k,
+                                        v,
+                                        attn_metadata,
+                                        out_scale=out_scale,
+                                        attention_mask=attention_mask,
+                                        mrope_config=mrope_config)
+
+        attn_output = self.o_proj(attn_output,
+                                  all_reduce_params=all_reduce_params)
+
+        return attn_output
 
     def forward(
         self,
@@ -211,9 +165,16 @@
     ) -> torch.Tensor:
         assert lora_params is None, "LORA is not supported for Llama4Attention"
         if self.use_rope:
-            return self._forward_rope(position_ids, hidden_states,
-                                      attn_metadata, attention_mask,
-                                      mrope_config, all_reduce_params)
+            return super().forward(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                attention_mask=attention_mask,
+                mrope_config=mrope_config,
+                all_reduce_params=all_reduce_params,
+                lora_params=lora_params,
+                **kwargs,
+            )
         else:
             return self._forward_nope(position_ids, hidden_states,
                                       attn_metadata, attention_mask,
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py
index 34d2357f39..1361fc6ae5 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -9,11 +9,12 @@ from tensorrt_llm.functional import PositionEmbeddingType
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention
+from ..modules.attention import Attention, QkNormType
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import TensorParallelMode
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..pipeline_interface import PipelineInterface
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -51,6 +52,7 @@ class Qwen3Attention(Attention):
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
+            qk_norm_type=QkNormType.pre_rope,
         )
 
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
@@ -64,6 +66,26 @@ class Qwen3Attention(Attention):
         self.aux_stream = torch.cuda.Stream()
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
+    def apply_qk_norm(self, q, k):
+
+        def q_l2norm():
+            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
+                -1, self.q_size)
+
+        def k_l2norm():
+            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
+                -1, self.kv_size)
+
+        q, k = maybe_execute_in_parallel(
+            q_l2norm,
+            k_l2norm,
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
+        return q, k
+
 
 class Qwen3DecoderLayer(DecoderLayer):
 
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
index f6d1e504e4..e232cacb70 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -5,17 +5,14 @@ from torch import nn
 from tqdm import tqdm
 from transformers import Qwen3MoeConfig
 
-from tensorrt_llm.functional import PositionEmbeddingType
-
 from ..attention_backend import AttentionMetadata
-from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
+from .modeling_qwen3 import Qwen3Attention
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              duplicate_kv_weight, register_auto_model)
 
@@ -81,57 +78,13 @@ class Qwen3MoE(nn.Module):
         return final_hidden_states.view(orig_shape)
 
 
-class Qwen3MoEAttention(Attention):
-
-    def __init__(
-        self,
-        model_config: ModelConfig[Qwen3MoeConfig],
-        layer_idx: Optional[int] = None,
-    ):
-        config = model_config.pretrained_config
-        if getattr(config, "rope_scaling", None) is not None:
-            pos_embd_params = PositionalEmbeddingParams(
-                type=PositionEmbeddingType.from_string(
-                    config.rope_scaling["type"]),
-                rope=RopeParams.from_config(config),
-            )
-        else:
-            pos_embd_params = PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
-            )
-        super().__init__(
-            hidden_size=config.hidden_size,
-            num_attention_heads=config.num_attention_heads,
-            num_key_value_heads=config.num_key_value_heads,
-            max_position_embeddings=config.max_position_embeddings,
-            bias=config.attention_bias,
-            pos_embd_params=pos_embd_params,
-            layer_idx=layer_idx,
-            dtype=config.torch_dtype,
-            dense_bias=config.attention_bias,
-            config=model_config,
-        )
-
-        self.q_norm = RMSNorm(hidden_size=self.head_dim,
-                              eps=1e-6,
-                              dtype=config.torch_dtype,
-                              has_weights=True)
-        self.k_norm = RMSNorm(hidden_size=self.head_dim,
-                              eps=1e-6,
-                              dtype=config.torch_dtype,
-                              has_weights=True)
-        self.aux_stream = torch.cuda.Stream()
-        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
-
-
 class Qwen3MoEDecoderLayer(DecoderLayer):
 
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
                  layer_idx: int, aux_stream: torch.cuda.Stream):
         super().__init__()
         config = model_config.pretrained_config
-        self.self_attn = Qwen3MoEAttention(
+        self.self_attn = Qwen3Attention(
             model_config,
             layer_idx=layer_idx,
         )
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 7fa18f5f6b..88ff963e7e 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1,4 +1,5 @@
 import math
+from enum import IntEnum
 from typing import Optional
 
 import torch
@@ -19,6 +20,12 @@ from .rms_norm import RMSNorm
 from .rotary_embedding import RotaryEmbedding
 
 
+class QkNormType(IntEnum):
+    none = 0
+    pre_rope = 1
+    post_rope = 2
+
+
 class Attention(nn.Module):
 
     def __init__(
@@ -34,6 +41,7 @@ class Attention(nn.Module):
         dtype: torch.dtype = None,
         dense_bias: Optional[bool] = None,
         config: Optional[ModelConfig] = None,
+        qk_norm_type: QkNormType = QkNormType.none,
     ):
         super().__init__()
         self.layer_idx = layer_idx
@@ -41,15 +49,13 @@ class Attention(nn.Module):
         config = config or ModelConfig()
         self.hidden_size = hidden_size
         self.num_heads = num_attention_heads
-        if config:
-            self.head_dim = getattr(config.pretrained_config, "head_dim",
-                                    self.hidden_size // self.num_heads)
-        else:
-            self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = getattr(config.pretrained_config, "head_dim",
+                                self.hidden_size // self.num_heads)
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings
         self.pos_embd_params = pos_embd_params
+        self.qk_norm_type = qk_norm_type
 
         self.dense_bias = dense_bias
         if dense_bias is None:
@@ -119,12 +125,9 @@ class Attention(nn.Module):
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])
 
-        use_qk_norm = (config.pretrained_config and
-                       (config.pretrained_config.model_type == 'qwen3'
-                        or config.pretrained_config.model_type == 'qwen3_moe'))
         attn_cls = get_attention_backend(self.attn_backend)
         self.enable_rope_fusion = attn_cls.support_fused_rope(
-        ) and not use_qk_norm
+        ) and qk_norm_type != QkNormType.post_rope
         self.attn = create_attention(
             self.attn_backend,
             self.layer_idx,
@@ -157,9 +160,14 @@ class Attention(nn.Module):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)
 
+    def split_qkv(self, q, k=None, v=None):
+        if k is None and v is None:
+            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        return q, k, v
+
     def convert_qkv(self, q, k, v):
         if k is None and v is None and not self.support_fused_qkv:
-            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            q, k, v = self.split_qkv(q)
         elif k is not None and v is not None and self.support_fused_qkv:
             qkv = torch.concat([q, k, v], dim=-1)
             q, k, v = qkv, None, None
@@ -191,32 +199,14 @@
                 qkv = qkv + qkv_lora
         q, k, v = qkv, None, None
 
+        if self.qk_norm_type == QkNormType.pre_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
         if self.apply_rotary_emb and position_ids is not None:
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-            if hasattr(self, 'q_norm') and hasattr(self, 'k_norm'):
-                # Add qk-norm
-                if hasattr(self, 'ln_events'):
-                    q_l2norm = lambda: self.q_norm(q.reshape(-1, self.head_dim)
-                                                   ).reshape(-1, self.q_size)
-                    k_l2norm = lambda: self.k_norm(k.reshape(-1, self.head_dim)
-                                                   ).reshape(-1, self.kv_size)
-                    q, k = maybe_execute_in_parallel(
-                        q_l2norm,
-                        k_l2norm,
-                        self.ln_events[0],
-                        self.ln_events[1],
-                        self.aux_stream,
-                    )
-                else:
-                    q_by_head = q.reshape(-1, self.head_dim)
-                    q_by_head = self.q_norm(q_by_head)
-                    q = q_by_head.view(q.shape)
-                    k_by_head = k.reshape(-1, self.head_dim)
-                    k_by_head = self.k_norm(k_by_head)
-                    k = k_by_head.view(k.shape)
-
+            q, k, v = self.split_qkv(q, k, v)
             q, k = self.rotary_emb(position_ids, [q, k])
+        if self.qk_norm_type == QkNormType.post_rope:
+            q, k = self.apply_qk_norm(q, k)
 
         out_scale = None
         if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
@@ -237,6 +227,11 @@ class Attention(nn.Module):
                                   layer_idx=self.layer_idx)
         return attn_output
 
+    def apply_qk_norm(self, q, k):
+        raise NotImplementedError(
+            f"QK norm is not implemented for {self.__class__.__name__}."
+            "Please override the `apply_qk_norm` method in the subclass.")
+
 
 class MLA(nn.Module):
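
Illustrative usage note (not part of the patch): the sketch below shows how a model-specific attention class is expected to hook into the new QK-norm plumbing in modules/attention.py. A subclass passes qk_norm_type to the base Attention constructor and overrides apply_qk_norm; the base forward() then invokes the hook before the rotary embedding for QkNormType.pre_rope (RoPE fusion stays enabled) or after it for QkNormType.post_rope (RoPE fusion is disabled). MyAttention and its config fields are hypothetical, and pos_embd_params is elided for brevity; the pattern mirrors Qwen3Attention in this diff.

# Hypothetical subclass sketch; assumes the patch above is applied.
from typing import Optional

import torch

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.attention import Attention, QkNormType
from tensorrt_llm._torch.modules.multi_stream_utils import \
    maybe_execute_in_parallel
from tensorrt_llm._torch.modules.rms_norm import RMSNorm


class MyAttention(Attention):

    def __init__(self,
                 model_config: ModelConfig,
                 layer_idx: Optional[int] = None):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.attention_bias,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
            # pre_rope: apply_qk_norm() runs before the rotary embedding and
            # fused RoPE stays enabled; post_rope runs after RoPE and disables
            # RoPE fusion (see Attention.__init__ in the patch).
            qk_norm_type=QkNormType.pre_rope,
        )
        self.q_norm = RMSNorm(hidden_size=self.head_dim, eps=1e-6,
                              dtype=config.torch_dtype, has_weights=True)
        self.k_norm = RMSNorm(hidden_size=self.head_dim, eps=1e-6,
                              dtype=config.torch_dtype, has_weights=True)
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def apply_qk_norm(self, q, k):
        # Normalize per head, then flatten back to the packed [*, q_size] /
        # [*, kv_size] layout the attention backend expects. The two norms are
        # independent, so they can run on separate CUDA streams.
        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        return maybe_execute_in_parallel(q_l2norm, k_l2norm,
                                         self.ln_events[0], self.ln_events[1],
                                         self.aux_stream)

Subclasses that do not opt in keep the default QkNormType.none; the base-class apply_qk_norm only raises NotImplementedError if a subclass selects pre_rope or post_rope without overriding it.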