Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
refactor: Allow models to override apply_qk_norm. (#4078)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>

Parent: c9e2a963e0
Commit: b35f9a67f9
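
The diff below replaces the model-specific qk-norm branches in the shared Attention module with a QkNormType enum and an apply_qk_norm hook that subclasses override: Qwen3 normalizes before RoPE, Llama4 after RoPE. A minimal, runnable sketch of that contract follows; the toy class names and the simplified weightless RMS norm are illustrative only, not TensorRT-LLM APIs.

import math
from enum import IntEnum


class QkNormType(IntEnum):
    none = 0
    pre_rope = 1
    post_rope = 2


def _rms_norm(x, eps=1e-6):
    # Weightless RMS norm over a flat list of floats (stand-in for RMSNorm).
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


class ToyAttention:
    """Toy stand-in for the refactored Attention base class."""

    def __init__(self, qk_norm_type: QkNormType = QkNormType.none):
        self.qk_norm_type = qk_norm_type

    def apply_qk_norm(self, q, k):
        # Base class behavior: models that request qk norm must override this hook.
        raise NotImplementedError(
            f"QK norm is not implemented for {type(self).__name__}.")

    def apply_rope(self, q, k):
        # Placeholder for the rotary-embedding / fused-RoPE path.
        return q, k

    def forward(self, q, k, v):
        if self.qk_norm_type == QkNormType.pre_rope:    # Qwen3-style
            q, k = self.apply_qk_norm(q, k)
        q, k = self.apply_rope(q, k)
        if self.qk_norm_type == QkNormType.post_rope:   # Llama4-style
            q, k = self.apply_qk_norm(q, k)
        return q, k, v


class ToyQwen3Attention(ToyAttention):

    def __init__(self):
        super().__init__(qk_norm_type=QkNormType.pre_rope)

    def apply_qk_norm(self, q, k):
        # Real code applies per-head RMSNorm modules, possibly on an aux stream.
        return _rms_norm(q), _rms_norm(k)


if __name__ == "__main__":
    attn = ToyQwen3Attention()
    print(attn.forward([1.0, -2.0], [3.0, 4.0], [5.0, 6.0]))
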
@@ -22,7 +22,7 @@ from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
                                           PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (FusedMoE, Llama4RenormalizeMoeRoutingMethod,
@@ -32,7 +32,6 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
                              WeightsLoadingConfig)
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -53,86 +52,51 @@ class Llama4Attention(Attention):
        aux_stream: Optional[torch.cuda.Stream] = None,
    ):
        config = model_config.pretrained_config
        self.aux_stream = aux_stream
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

        self.use_rope = not nope_layer
        self.use_qk_norm = use_qk_norm and not nope_layer
        if self.use_rope and not self.use_qk_norm:
            pos_embd_params = PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gptj,
                rope=RopeParams.from_config(config),
                is_neox=False,
            )
        else:
            pos_embd_params = None
        pos_embd_params = PositionalEmbeddingParams(
            type=PositionEmbeddingType.rope_gptj,
            rope=RopeParams.from_config(config),
            is_neox=False,
        ) if self.use_rope else None

        super().__init__(hidden_size=config.hidden_size,
                         num_attention_heads=config.num_attention_heads,
                         num_key_value_heads=config.num_key_value_heads,
                         max_position_embeddings=config.max_position_embeddings,
                         bias=config.attention_bias,
                         pos_embd_params=pos_embd_params,
                         layer_idx=layer_idx,
                         dtype=config.torch_dtype,
                         config=model_config)
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.attention_bias,
            pos_embd_params=pos_embd_params,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
            qk_norm_type=QkNormType.post_rope
            if use_qk_norm else QkNormType.none,
        )

        if self.use_rope and self.use_qk_norm:
            # here we must disable rope fusion regardless of attn_backend
            self.enable_rope_fusion = False
            self.rotary_emb = RotaryEmbedding(
                RopeParams.from_config(config),
                head_dim=self.head_dim,
                is_neox=False,
            )

        if self.use_qk_norm:
        if self.use_rope and use_qk_norm:
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.qk_norm = RMSNorm(hidden_size=self.head_dim,
                                   eps=1e-6,
                                   dtype=config.torch_dtype,
                                   has_weights=False)
        else:
            self.qk_norm = None
        self.aux_stream = aux_stream
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

        self.attn_temperature_tuning = attn_temperature_tuning and nope_layer
        self.floor_scale = getattr(config, "floor_scale", 8192.0)
        self.attn_scale = getattr(config, "attn_scale", 0.1)

    def _attn_qkv(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
        CAUSAL,
        mrope_config: Optional[dict] = None,
        all_reduce_params: Optional[AllReduceParams] = None):
        out_scale = None
        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
            out_scale = self.o_proj.inv_input_scale
    def apply_qk_norm(self, q, k):

        q, k, v = self.convert_qkv(q, k, v)
        attn_output = self.attn.forward(q,
                                        k,
                                        v,
                                        attn_metadata,
                                        out_scale=out_scale,
                                        attention_mask=attention_mask,
                                        mrope_config=mrope_config)
        def q_l2norm():
            return self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        attn_output = self.o_proj(attn_output,
                                  all_reduce_params=all_reduce_params)
        def k_l2norm():
            return self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        return attn_output

    def _qk_norm(self, q, k):
        # TODO: make this more efficient.
        q_l2norm = lambda: self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
            -1, self.q_size)
        k_l2norm = lambda: self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
            -1, self.kv_size)
        q, k = maybe_execute_in_parallel(
            q_l2norm,
            k_l2norm,
@@ -155,31 +119,6 @@ class Llama4Attention(Attention):
            q = (q * attn_scale).to(q.dtype)
        return q

    def _forward_rope(
        self,
        position_ids: Optional[torch.LongTensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
        CAUSAL,
        mrope_config: Optional[dict] = None,
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        if self.use_qk_norm:
            qkv = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
            assert self.rotary_emb is not None and not self.enable_rope_fusion, "qk_norm requires attention rope fusion disabled"
            q, k = self.rotary_emb(position_ids, [q, k])
            q, k = self._qk_norm(q, k)
            return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
                                  mrope_config, all_reduce_params)
        else:
            # When qk_norm is disabled, use the classic attention path that handles RoPE fusion
            return super().forward(position_ids, hidden_states, attn_metadata,
                                   attention_mask, mrope_config,
                                   all_reduce_params)

    def _forward_nope(
        self,
        position_ids: Optional[torch.LongTensor],
@@ -191,11 +130,26 @@ class Llama4Attention(Attention):
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k, v = self.split_qkv(qkv)
        if self.attn_temperature_tuning:
            q = self._attention_scaling(q, position_ids)
        return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
                              mrope_config, all_reduce_params)
        out_scale = None
        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
            out_scale = self.o_proj.inv_input_scale

        q, k, v = self.convert_qkv(q, k, v)
        attn_output = self.attn.forward(q,
                                        k,
                                        v,
                                        attn_metadata,
                                        out_scale=out_scale,
                                        attention_mask=attention_mask,
                                        mrope_config=mrope_config)

        attn_output = self.o_proj(attn_output,
                                  all_reduce_params=all_reduce_params)

        return attn_output

    def forward(
        self,
@@ -211,9 +165,16 @@ class Llama4Attention(Attention):
    ) -> torch.Tensor:
        assert lora_params is None, "LORA is not supported for Llama4Attention"
        if self.use_rope:
            return self._forward_rope(position_ids, hidden_states,
                                      attn_metadata, attention_mask,
                                      mrope_config, all_reduce_params)
            return super().forward(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                attention_mask=attention_mask,
                mrope_config=mrope_config,
                all_reduce_params=all_reduce_params,
                lora_params=lora_params,
                **kwargs,
            )
        else:
            return self._forward_nope(position_ids, hidden_states,
                                      attn_metadata, attention_mask,

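In the Llama4Attention hunks above, qk norm is applied after RoPE: q and k are flattened to (-1, head_dim), passed through a weightless RMSNorm, and reshaped back to (-1, q_size) / (-1, kv_size). Below is a small, hedged sketch of that reshape, normalize, reshape step, assuming a plain weightless RMSNorm; the helper names are illustrative and not part of TensorRT-LLM.

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Weightless RMSNorm over the last dimension (has_weights=False in the diff).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def l2norm_per_head(t: torch.Tensor, head_dim: int) -> torch.Tensor:
    # [num_tokens, num_heads * head_dim] -> normalize each head -> restore shape.
    flat_shape = t.shape
    t = t.reshape(-1, head_dim)
    t = rms_norm(t)
    return t.reshape(flat_shape)


if __name__ == "__main__":
    q = torch.randn(4, 8 * 128)   # 8 query heads of dim 128
    k = torch.randn(4, 2 * 128)   # 2 KV heads of dim 128
    q, k = l2norm_per_head(q, 128), l2norm_per_head(k, 128)
    print(q.shape, k.shape)
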
@@ -9,11 +9,12 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -51,6 +52,7 @@ class Qwen3Attention(Attention):
            dtype=config.torch_dtype,
            dense_bias=config.attention_bias,
            config=model_config,
            qk_norm_type=QkNormType.pre_rope,
        )

        self.q_norm = RMSNorm(hidden_size=self.head_dim,
@@ -64,6 +66,26 @@ class Qwen3Attention(Attention):
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def apply_qk_norm(self, q, k):

        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        q, k = maybe_execute_in_parallel(
            q_l2norm,
            k_l2norm,
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        return q, k


class Qwen3DecoderLayer(DecoderLayer):

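Qwen3Attention keeps its q_norm/k_norm modules and now exposes them through apply_qk_norm, running both norms via maybe_execute_in_parallel with an auxiliary CUDA stream and two events. Below is a hedged sketch of what such dual-stream overlap can look like in general; it illustrates the idea only and is not the TensorRT-LLM maybe_execute_in_parallel implementation.

import torch


def run_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream):
    # Run fn_b on aux_stream while the current stream runs fn_a, then rejoin.
    if aux_stream is None:
        return fn_a(), fn_b()  # sequential fallback (e.g. no CUDA available)
    event_a.record()  # mark that the inputs are ready on the current stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)
        out_b = fn_b()
        event_b.record()
    out_a = fn_a()
    torch.cuda.current_stream().wait_event(event_b)
    return out_a, out_b


if __name__ == "__main__":
    # CPU-friendly demonstration of the fallback path.
    print(run_in_parallel(lambda: "q_normed", lambda: "k_normed", None, None, None))
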
@@ -5,17 +5,14 @@ from torch import nn
from tqdm import tqdm
from transformers import Qwen3MoeConfig

from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from .modeling_qwen3 import Qwen3Attention
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                             duplicate_kv_weight, register_auto_model)

@@ -81,57 +78,13 @@ class Qwen3MoE(nn.Module):
        return final_hidden_states.view(orig_shape)


class Qwen3MoEAttention(Attention):

    def __init__(
        self,
        model_config: ModelConfig[Qwen3MoeConfig],
        layer_idx: Optional[int] = None,
    ):
        config = model_config.pretrained_config
        if getattr(config, "rope_scaling", None) is not None:
            pos_embd_params = PositionalEmbeddingParams(
                type=PositionEmbeddingType.from_string(
                    config.rope_scaling["type"]),
                rope=RopeParams.from_config(config),
            )
        else:
            pos_embd_params = PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gpt_neox,
                rope=RopeParams.from_config(config),
            )
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.attention_bias,
            pos_embd_params=pos_embd_params,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            dense_bias=config.attention_bias,
            config=model_config,
        )

        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=1e-6,
                              dtype=config.torch_dtype,
                              has_weights=True)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=1e-6,
                              dtype=config.torch_dtype,
                              has_weights=True)
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]


class Qwen3MoEDecoderLayer(DecoderLayer):

    def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
                 layer_idx: int, aux_stream: torch.cuda.Stream):
        super().__init__()
        config = model_config.pretrained_config
        self.self_attn = Qwen3MoEAttention(
        self.self_attn = Qwen3Attention(
            model_config,
            layer_idx=layer_idx,
        )

@@ -1,4 +1,5 @@
import math
from enum import IntEnum
from typing import Optional

import torch
@@ -19,6 +20,12 @@ from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding


class QkNormType(IntEnum):
    none = 0
    pre_rope = 1
    post_rope = 2


class Attention(nn.Module):

    def __init__(
@@ -34,6 +41,7 @@ class Attention(nn.Module):
        dtype: torch.dtype = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
        qk_norm_type: QkNormType = QkNormType.none,
    ):
        super().__init__()
        self.layer_idx = layer_idx
@@ -41,15 +49,13 @@ class Attention(nn.Module):
        config = config or ModelConfig()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        if config:
            self.head_dim = getattr(config.pretrained_config, "head_dim",
                                    self.hidden_size // self.num_heads)
        else:
            self.head_dim = self.hidden_size // self.num_heads
        self.head_dim = getattr(config.pretrained_config, "head_dim",
                                self.hidden_size // self.num_heads)
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.qk_norm_type = qk_norm_type
        self.dense_bias = dense_bias

        if dense_bias is None:
@@ -119,12 +125,9 @@ class Attention(nn.Module):
        self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                [self.hidden_size])

        use_qk_norm = (config.pretrained_config and
                       (config.pretrained_config.model_type == 'qwen3'
                        or config.pretrained_config.model_type == 'qwen3_moe'))
        attn_cls = get_attention_backend(self.attn_backend)
        self.enable_rope_fusion = attn_cls.support_fused_rope(
        ) and not use_qk_norm
        ) and qk_norm_type != QkNormType.post_rope
        self.attn = create_attention(
            self.attn_backend,
            self.layer_idx,
@@ -157,9 +160,14 @@ class Attention(nn.Module):
        # which could be modified after __init__
        self.attn.update_quant_config(self.quant_config)

    def split_qkv(self, q, k=None, v=None):
        if k is None and v is None:
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q, k, v

    def convert_qkv(self, q, k, v):
        if k is None and v is None and not self.support_fused_qkv:
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q, k, v = self.split_qkv(q)
        elif k is not None and v is not None and self.support_fused_qkv:
            qkv = torch.concat([q, k, v], dim=-1)
            q, k, v = qkv, None, None
@@ -191,32 +199,14 @@ class Attention(nn.Module):
            qkv = qkv + qkv_lora

        q, k, v = qkv, None, None
        if self.qk_norm_type == QkNormType.pre_rope:
            q, k, v = self.split_qkv(q, k, v)
            q, k = self.apply_qk_norm(q, k)
        if self.apply_rotary_emb and position_ids is not None:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
            if hasattr(self, 'q_norm') and hasattr(self, 'k_norm'):
                # Add qk-norm
                if hasattr(self, 'ln_events'):
                    q_l2norm = lambda: self.q_norm(q.reshape(-1, self.head_dim)
                                                   ).reshape(-1, self.q_size)
                    k_l2norm = lambda: self.k_norm(k.reshape(-1, self.head_dim)
                                                   ).reshape(-1, self.kv_size)
                    q, k = maybe_execute_in_parallel(
                        q_l2norm,
                        k_l2norm,
                        self.ln_events[0],
                        self.ln_events[1],
                        self.aux_stream,
                    )
                else:
                    q_by_head = q.reshape(-1, self.head_dim)
                    q_by_head = self.q_norm(q_by_head)
                    q = q_by_head.view(q.shape)
                    k_by_head = k.reshape(-1, self.head_dim)
                    k_by_head = self.k_norm(k_by_head)
                    k = k_by_head.view(k.shape)

            q, k, v = self.split_qkv(q, k, v)
            q, k = self.rotary_emb(position_ids, [q, k])
            if self.qk_norm_type == QkNormType.post_rope:
                q, k = self.apply_qk_norm(q, k)
        out_scale = None

        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
@@ -237,6 +227,11 @@ class Attention(nn.Module):
            layer_idx=self.layer_idx)
        return attn_output

    def apply_qk_norm(self, q, k):
        raise NotImplementedError(
            f"QK norm is not implemented for {self.__class__.__name__}."
            "Please override the `apply_qk_norm` method in the subclass.")


class MLA(nn.Module):

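The base-class hunks above also introduce split_qkv, which convert_qkv and the model code now reuse: a fused [..., q_size + 2 * kv_size] tensor is split along the last dimension, and already-split inputs pass through unchanged. A standalone illustration of that contract follows; the toy class is not the real Attention module.

import torch


class ToySplitter:
    """Re-creates the split_qkv contract outside of the Attention class."""

    def __init__(self, q_size: int, kv_size: int):
        self.q_size, self.kv_size = q_size, kv_size

    def split_qkv(self, q, k=None, v=None):
        # Split a fused qkv tensor; no-op if k and v are already provided.
        if k is None and v is None:
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q, k, v


if __name__ == "__main__":
    s = ToySplitter(q_size=1024, kv_size=256)
    qkv = torch.randn(2, 1024 + 2 * 256)
    q, k, v = s.split_qkv(qkv)
    print(q.shape, k.shape, v.shape)  # (2, 1024) (2, 256) (2, 256)
    q, k, v = s.split_qkv(q, k, v)    # already split: returned unchanged
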