Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
feat: Integration of Fused QKNorm+RoPE. (#4611)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

commit 9c4b8f66b4 (parent 6493401986)
@@ -53,11 +53,11 @@ class Gemma3Attention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=pos_embd_params,
+            qk_norm_type=QkNormType.pre_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
             config=model_config,
-            qk_norm_type=QkNormType.pre_rope,
             q_scaling=q_scaling,
         )
         self.q_norm = RMSNorm(hidden_size=config.head_dim,
@@ -73,11 +73,11 @@ class Llama4Attention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
+            qk_norm_type=QkNormType.post_rope
+            if use_qk_norm else QkNormType.none,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
-            qk_norm_type=QkNormType.post_rope
-            if use_qk_norm else QkNormType.none,
             attention_chunk_size=attention_chunk_size,
         )
@@ -26,8 +26,10 @@ class Qwen3Attention(Attention):
         self,
         model_config: ModelConfig[Qwen3Config],
         layer_idx: Optional[int] = None,
+        fuse_qk_norm_rope: bool = True,
     ):
         config = model_config.pretrained_config

         if getattr(config, "rope_scaling", None) is not None:
             pos_embd_params = PositionalEmbeddingParams(
                 type=PositionEmbeddingType.from_string(
@@ -40,20 +42,27 @@ class Qwen3Attention(Attention):
                 rope=RopeParams.from_config(config),
             )

+        self.fuse_qk_norm_rope = fuse_qk_norm_rope
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
-            pos_embd_params=pos_embd_params,
+            pos_embd_params=pos_embd_params
+            if not self.fuse_qk_norm_rope else None,
+            qk_norm_type=QkNormType.pre_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
-            qk_norm_type=QkNormType.pre_rope,
         )

+        # If fuse_qk_norm_rope is true, we pass pos_embd_params=None to super().__init__,
+        # so we need to assign it here to record the actual pos_embd_params.
+        self.pos_embd_params = pos_embd_params
+
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=1e-6,
                               dtype=config.torch_dtype,
@@ -85,6 +94,21 @@ class Qwen3Attention(Attention):

         return q, k

+    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
+        if not self.fuse_qk_norm_rope:
+            return super().apply_rope(qkv, position_ids)
+        else:
+            return self.apply_qk_norm_rope(qkv, position_ids)
+
+    def apply_qk_norm_rope(self, qkv, position_ids):
+        torch.ops.trtllm.fused_qk_norm_rope(
+            qkv, self.num_heads, self.num_key_value_heads,
+            self.num_key_value_heads, self.head_dim,
+            self.q_norm.variance_epsilon, self.q_norm.weight,
+            self.k_norm.weight, self.pos_embd_params.rope.theta,
+            self.pos_embd_params.is_neox, position_ids.view(-1))
+        return qkv, None, None
+

 class Qwen3DecoderLayer(DecoderLayer):
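
The fused kernel behind torch.ops.trtllm.fused_qk_norm_rope is not part of this excerpt. As a rough, unfused reference for what its arguments suggest it computes (per-head RMSNorm on Q and K, then RoPE, applied to the packed QKV tensor), here is a plain PyTorch sketch. It assumes a [num_tokens, (num_q_heads + 2 * num_kv_heads) * head_dim] QKV layout and NeoX-style (rotate-half) rotation, and that K and V use the same number of heads; it is an illustration of the expected semantics, not the TensorRT-LLM implementation.

import torch


def qk_norm_rope_reference(qkv, num_q_heads, num_kv_heads, head_dim, eps,
                           q_weight, k_weight, theta, position_ids):
    # Unpack [num_tokens, (num_q_heads + 2 * num_kv_heads) * head_dim] into heads.
    num_tokens = qkv.shape[0]
    heads = qkv.view(num_tokens, num_q_heads + 2 * num_kv_heads, head_dim)
    q, k, v = torch.split(heads, [num_q_heads, num_kv_heads, num_kv_heads], dim=1)

    def rms_norm(x, weight):
        # Per-head RMSNorm over head_dim, mirroring self.q_norm / self.k_norm.
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

    q, k = rms_norm(q, q_weight), rms_norm(k, k_weight)

    # NeoX-style RoPE: rotate the two halves of every Q/K head.
    inv_freq = 1.0 / (theta ** (torch.arange(
        0, head_dim, 2, dtype=torch.float32, device=qkv.device) / head_dim))
    angles = position_ids.float()[:, None] * inv_freq[None, :]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]

    def rope(x):
        x1 = x[..., :head_dim // 2].float()
        x2 = x[..., head_dim // 2:].float()
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos],
                         dim=-1).to(x.dtype)

    q, k = rope(q), rope(k)
    return torch.cat([q, k, v], dim=1).view(num_tokens, -1)

In the diff the fused op also receives an is_neox flag and separate K/V head counts; this sketch fixes the rotate-half convention for brevity.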
@@ -43,11 +43,11 @@ class Attention(nn.Module):
         max_position_embeddings: int,
         bias: bool,
         pos_embd_params: Optional[PositionalEmbeddingParams] = None,
+        qk_norm_type: QkNormType = QkNormType.none,
         layer_idx: Optional[int] = None,
         dtype: torch.dtype = None,
         dense_bias: Optional[bool] = None,
         config: Optional[ModelConfig] = None,
-        qk_norm_type: QkNormType = QkNormType.none,
         q_scaling: float = 1.0,
         attention_chunk_size: Optional[int] = None,
     ):
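
QkNormType is used throughout this diff, but its definition is not part of the excerpt. For orientation only, a minimal sketch of the enum these call sites appear to rely on (the real definition lives elsewhere in the tensorrt_llm sources and may differ in detail):

from enum import Enum


class QkNormType(Enum):
    none = "none"            # no QK normalization
    pre_rope = "pre_rope"    # RMSNorm on Q/K before RoPE (Gemma3, Qwen3 above)
    post_rope = "post_rope"  # RMSNorm on Q/K after RoPE (Llama4 with use_qk_norm)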
@@ -61,11 +61,11 @@ class Attention(nn.Module):
             max_position_embeddings (int): The maximum position embeddings.
             bias (bool): Whether to use bias in the linear layers.
             pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
+            qk_norm_type (QkNormType): The type of QK normalization.
             layer_idx (int): The layer index.
             dtype (torch.dtype): The data type.
             dense_bias (bool): Whether to use bias in the output projection layer.
             config (ModelConfig): The model configuration.
-            qk_norm_type (QkNormType): The type of QK normalization.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V$, where $qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
             attention_chunk_size (int): See [Chunked Attention] below.
         """
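
As a quick illustration of the q_scaling docstring above (values are arbitrary examples, not defaults of any particular model):

import math

head_dim, q_scaling = 128, 1.0                      # example values only
qk_scale = 1.0 / (math.sqrt(head_dim) * q_scaling)  # scale applied to Q @ K^T
# attention output: softmax(Q @ K.T * qk_scale) @ V
print(qk_scale)  # ~0.0884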
@@ -154,7 +154,7 @@ class Attention(nn.Module):

         self.quant_config = config.get_quant_config()
         self.attn_backend = config.attn_backend
         self.pos_embd_params = pos_embd_params
-        attn_cls = get_attention_backend(self.attn_backend)

         # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
         # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
@@ -169,9 +169,20 @@ class Attention(nn.Module):
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])

+        attn_cls = get_attention_backend(self.attn_backend)
+        # enable_rope_fusion: Whether to fuse RoPE into the attention OP.
+        # If true, RoPE will be applied in self.attn.forward.
+        # If false, RoPE will be applied in self.apply_rope.
         self.enable_rope_fusion = attn_cls.support_fused_rope(
-        ) and qk_norm_type != QkNormType.post_rope
+        ) and self.qk_norm_type != QkNormType.post_rope
+
+        self.rotary_emb = None
+        if not self.enable_rope_fusion and self.pos_embd_params is not None:
+            self.rotary_emb = RotaryEmbedding(
+                self.pos_embd_params.rope,
+                head_dim=self.head_dim,
+                is_neox=self.pos_embd_params.is_neox,
+            )

         self.attn = create_attention(
             self.attn_backend,
             self.layer_idx,
@@ -188,16 +199,6 @@ class Attention(nn.Module):

         self.support_fused_qkv = self.attn.support_fused_qkv()

-        self.rotary_emb = None
-        self.apply_rotary_emb = (not self.enable_rope_fusion
-                                 and pos_embd_params is not None)
-        if self.apply_rotary_emb:
-            self.rotary_emb = RotaryEmbedding(
-                pos_embd_params.rope,
-                head_dim=self.head_dim,
-                is_neox=pos_embd_params.is_neox,
-            )
-
         if not config.skip_create_weights_in_init:
             self.create_weights()
@@ -261,17 +262,9 @@ class Attention(nn.Module):
         if qkv_lora is not None:
             qkv = qkv + qkv_lora

-        q, k, v = qkv, None, None
-        if self.qk_norm_type == QkNormType.pre_rope:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.apply_qk_norm(q, k)
-        if self.apply_rotary_emb and position_ids is not None:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.rotary_emb(position_ids, [q, k])
-        if self.qk_norm_type == QkNormType.post_rope:
-            q, k = self.apply_qk_norm(q, k)
-        out_scale = None
+        q, k, v = self.apply_rope(qkv, position_ids)

+        out_scale = None
         if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
             out_scale = self.o_proj.inv_input_scale
@@ -297,6 +290,29 @@ class Attention(nn.Module):
             f"QK norm is not implemented for {self.__class__.__name__}."
             "Please override the `apply_qk_norm` method in the subclass.")

+    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
+        """
+        Apply RoPE to the query and key, possibly including QK norm.
+        Args:
+            qkv (torch.Tensor): The query, key, and value tensor.
+            position_ids (torch.Tensor): The position IDs of each token for RoPE.
+        Returns:
+            tuple: A tuple of (q, k, v).
+        This method may be overridden in a subclass; depending on the implementation, k/v may be None and q may be the concatenated qkv tensor.
+        Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirements of self.attn.
+        """
+        q, k, v = qkv, None, None
+        if self.qk_norm_type == QkNormType.pre_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
+        if not self.enable_rope_fusion and position_ids is not None:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.rotary_emb(position_ids, [q, k])
+        if self.qk_norm_type == QkNormType.post_rope:
+            q, k = self.apply_qk_norm(q, k)
+
+        return q, k, v
+

 def extract_extra_attrs(layer_idx: str):
     extra_attrs = get_model_extra_attrs()
@@ -52,10 +52,10 @@ class DecodingCUDAGraphRunner:
         # Using ones instead of zeros prevents NaNs in e.g. Deepseek
         self.input_ids = torch.ones((batch_size * token_per_request, ),
                                     device=device,
-                                    dtype=torch.int64)
+                                    dtype=torch.int32)
         self.position_ids = torch.zeros((1, batch_size * token_per_request),
                                         device=device,
-                                        dtype=torch.int64)
+                                        dtype=torch.int32)

         self.extra_model_inputs = {}
         self.attn_metadata = attn_metadata
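
For context on why DecodingCUDAGraphRunner pre-allocates input_ids and position_ids with a fixed shape and dtype, the standard torch.cuda.CUDAGraph capture/replay recipe looks roughly like the following. This is a generic sketch of the pattern only; capture_decode_step and model_fn are illustrative names, not TensorRT-LLM APIs.

import torch


def capture_decode_step(model_fn, batch_size, device="cuda"):
    # Static input buffers: shapes and dtypes are frozen at capture time, so
    # callers can only copy new values into them before each replay.
    input_ids = torch.ones((batch_size, ), device=device, dtype=torch.int32)
    position_ids = torch.zeros((1, batch_size), device=device, dtype=torch.int32)

    # Warm up on a side stream before capture (the recipe recommended by PyTorch).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model_fn(input_ids, position_ids)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model_fn(input_ids, position_ids)

    def replay(new_input_ids, new_position_ids):
        # Refresh the captured buffers in place, then replay the recorded graph.
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return static_out

    return replay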